def make_env_with_constraints(env_id, novelty_family=None, inject=False):
    env2 = gym.make(env_id)
    env2.unbreakable_items.add(
        'crafting_table'
    )  # Make crafting table unbreakable for easy solving of task.
    env2 = LimitActions(env2,
                        {'Forward', 'Left', 'Right', 'Break', 'Craft_bow'
                         })  # limit actions for easy training
    env2 = LidarInFront(
        env2)  # generate the observation space using LIDAR sensors
    # print(env.unbreakable_items)
    env2.reward_done = 1000
    env2.reward_intermediate = 50
    if inject:
        env2 = inject_novelty(env2, novelty_family[0], novelty_family[1],
                              novelty_family[2], novelty_family[3])
        if novelty_family[0] == 'breakincrease':
            print(
                "Break increase novelty injected::: self.env.itemtobreakmore = {}"
                .format(env2.itemtobreakmore))
        if novelty_family[0] == 'remapaction':
            print(
                "Action remap novelty check:: self.env.limited_actions_id = {}"
                .format(env2.limited_actions_id))
    check_env(env2, warn=True)  # check the environment
    return env2
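A hedged usage sketch of the constructor above; the Gym id and most of the novelty tuple entries are illustrative placeholders, not values taken from the original code:

# Plain constrained environment, no novelty injected
env = make_env_with_constraints('NovelGridworld-Bow-v0')

# With an injected novelty; only 'breakincrease' is grounded in the code above,
# the remaining three tuple entries are placeholders.
novelty_family = ['breakincrease', 'tree_log', 'easy', 0]
env = make_env_with_constraints('NovelGridworld-Bow-v0', novelty_family, inject=True)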
Example #2
def main():
    dm, y_oracle = init_dm(CONFIG)
    print(dm)

    cnn = ConvNet(n_output=2)
    env = QueryEnv(dm, cnn, CONFIG)
    env = MonitorWrapper(env, autolog=True)
    check_env(env, warn=True)

    model = load_dqn_model('data/rl_rps.pth')

    n_queries = 10
    n_epochs_cnn = 8
    for k in range(n_queries):
        print(dm)
        done = False
        obs = env.reset()
        while not done:
            action, _states = model.predict(obs)
            obs, reward, done, info = env.step(action)

        query_indicies = env.get_query_indicies()
        dm.label_samples(query_indicies, y_oracle[query_indicies])
        y_oracle = np.delete(y_oracle, query_indicies, axis=0)

        cnn.fit(*dm.train.get_xy(), n_epochs_cnn, batch_size=32)
        cnn.evaluate(*dm.eval.get_xy())

    cnn.evaluate(*dm.test.get_xy())
def test_env_init(file_array, stock_number, action_space):
    from stable_baselines.common.env_checker import check_env
    env = StockTradingEnvV1(df=None, file_array=file_array)
    check_env(env)
    assert env.action_space == spaces.MultiDiscrete(action_space)
    assert env.stock_number == stock_number
    assert env.reset().shape == env.observation_space.shape
    assert env.obs_column_number == 1 * stock_number
Example #4
def main():
    """
    Tests gym-compatibility of envs
    """
    env = PokeEnv()
    check_env(env)

    print('Test complete.')
    return True
    def test_stable_baselines_check(self):
        '''Use the environment checker from Stable Baselines to test
        the environment. This checks that the environment follows the
        Gym API and, optionally, that it is compatible with the
        Stable-Baselines repository.
        '''

        check_env(self.env, warn=True)
def test_id_env():
    env = IdEnv(max_steps=6)
    check_env(env)
    assert env.reset()[0] == 0
    collector = EnvDataCollector(env)

    episode(collector)
    collector.flush()
    assert len(collector.raw_data[0]) == 7
Example #7
def check_environment(env_name):
    from gym_mimic_envs.monitor import Monitor as EnvMonitor
    from stable_baselines.common.env_checker import check_env
    env = gym.make(env_name)
    log('Checking custom environment')
    check_env(env)
    env = EnvMonitor(env)
    log('Checking custom env in custom monitor wrapper')
    check_env(env)
    exit(33)
def test_env_and_pre_trained_agent():
    panther_agent_filepath = '/data/agents/models/test_20200325_184254/PPO2_20200325_184254_panther/' #This model is in minio or the package.
    # env = plark_env.PlarkEnv(driving_agent='pelican',panther_agent_filepath=panther_agent_filepath,config_file_path='/Components/plark-game/plark_game/game_config/10x10/balanced.json')
    env = gym.make('plark-env-v0', driving_agent='pelican',panther_agent_filepath=panther_agent_filepath)       
    check_env(env)
    
    training_steps = 500
    model = PPO2('CnnPolicy', env)
    model.learn(training_steps)

    n_eval_episodes = 2
    mean_reward, n_steps, victories = helper.evaluate_policy(model, env, n_eval_episodes=n_eval_episodes, deterministic=False, render=False, callback=None, reward_threshold=None, return_episode_rewards=False)
Example #9
def test_non_default_spaces(new_obs_space):
    env = gym.make('BreakoutNoFrameskip-v4')
    env.observation_space = new_obs_space
    # Patch methods to avoid errors
    env.reset = new_obs_space.sample

    def patched_step(_action):
        return new_obs_space.sample(), 0.0, False, {}

    env.step = patched_step
    with pytest.warns(UserWarning):
        check_env(env)
Example #10
    def __init__(self, model: abc.ABCMeta):
        # Define new environment
        self.env = RaspEnv()
        # Check if environment is ok
        check_env(self.env)
        # Define empty model
        self.model = model
        # Configure model hyperparameters
        self.config = {
            'learning_starts': 32,
            'target_network_update_freq': 100,
            'learning_rate': 0.001
        }
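A hedged sketch of how the stored model class and the hyperparameters above might later be combined; the build_agent helper and the 'MlpPolicy' name are illustrative assumptions, not part of the original class:

    def build_agent(self):
        # Hypothetical helper (assumption): instantiate the stored model class
        # (e.g. stable_baselines.DQN) on the checked environment with the
        # hyperparameters collected in self.config.
        return self.model('MlpPolicy', self.env, **self.config)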
Example #11
def check_reset_assert_error(env, new_reset_return):
    """
    Helper to check that the error is caught.
    :param env: (gym.Env)
    :param new_reset_return: (Any)
    """
    def wrong_reset():
        return new_reset_return

    # Patch the reset method with a wrong one
    env.reset = wrong_reset
    with pytest.raises(AssertionError):
        check_env(env)
Example #12
def check_step_assert_error(env, new_step_return=()):
    """
    Helper to check that the error is caught.
    :param env: (gym.Env)
    :param new_step_return: (tuple)
    """
    def wrong_step(_action):
        return new_step_return

    # Patch the step method with a wrong one
    env.step = wrong_step
    with pytest.raises(AssertionError):
        check_env(env)
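A hedged usage sketch of the two helpers above; the environment and the malformed return values are illustrative:

import gym

env = gym.make('CartPole-v0')
# reset() returning an observation of the wrong shape should trip check_env
check_reset_assert_error(env, env.observation_space.sample()[:-1])

# use a fresh env so the patched reset above does not interfere
env = gym.make('CartPole-v0')
# step() returning a 3-tuple instead of (obs, reward, done, info) should trip it too
check_step_assert_error(env, (env.observation_space.sample(), 0.0, False))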
def test_env_and_agent():
    # env = plark_env.PlarkEnv(config_file_path='/Components/plark-game/plark_game/game_config/10x10/balanced.json')
    env = gym.make('plark-env-v0')
    check_env(env)
    
    training_steps = 10
    model = PPO2('CnnPolicy', env)
    model.learn(training_steps)

    n_eval_episodes = 2
    mean_reward, n_steps, victories = helper.evaluate_policy(model, env, n_eval_episodes=n_eval_episodes, deterministic=False, render=False, callback=None, reward_threshold=None, return_episode_rewards=False)
    
    assert "PlarkEnv" in str(type(env))
    assert "PPO2" in str(type(model))
def test_normalize_false():
    from stable_baselines.common.env_checker import check_env
    env = StockTradingEnvV1(debug=True, normalize_observation=False)
    check_env(env)

    assert env.observation_space.shape == (14, )
    obs = env.reset()
    stacked_obs_part = np.reshape(
        obs[:env.obs_column_number * env.observation_frame],
        (-1, env.stock_number))
    index_end = env.current_step + env.start_point
    index_start = env.current_step + env.start_point - env.observation_frame + 1
    ref_obs = env.df[
        env.observation_column_name_array].loc[index_start:index_end].values
    assert ((ref_obs == stacked_obs_part).all()).all()
def test_illegal_move_limit_driving_pelican():
    panther_agent_filepath = '/Components/plark-game/plark_game/agents/basic/pantherAgent_move_north.py' #This model is in minio or the package.
    # env = plark_env.PlarkEnv(driving_agent='pelican',panther_agent_filepath=panther_agent_filepath, panther_agent_name='Panther_Agent_Move_North',config_file_path='/Components/plark-game/plark_game/game_config/10x10/balanced.json')
    env = gym.make('plark-env-v0', driving_agent='pelican',panther_agent_filepath=panther_agent_filepath, panther_agent_name='Panther_Agent_Move_North')    
    check_env(env)
    env.reset()
    # Repeat illegal move (illegal move limit - 1) times
    for i in range(9):
        obs, reward, done, info = env.step(0)
        assert info['illegal_move'] == True, "Should have made illegal move"
        assert info['turn'] == 0, "Should still be on first turn"
    # Next illegal move should end turn    
    obs, reward, done, info = env.step(0)
    assert info['illegal_move'] == True
    assert info['turn'] == 1
def test_illegal_move_limit_driving_panther():
    pelican_agent_filepath = '/data/agents/models/test_20200325_184254/PPO2_20200325_184254_pelican/' #This model is in minio or the package.
    # env = plark_env.PlarkEnv(driving_agent='panther',pelican_agent_filepath=pelican_agent_filepath,config_file_path='/Components/plark-game/plark_game/game_config/10x10/balanced.json')
    env = gym.make('plark-env-v0', driving_agent='panther',pelican_agent_filepath=pelican_agent_filepath)  
    check_env(env)
    env.reset()
    # Repeat illegal move (illegal move limit - 1) times
    for i in range(9):
        obs, reward, done, info = env.step(3)
        assert info['illegal_move'] == True, "Should have made illegal move"
        assert info['turn'] == 0, "Should still be on first turn"
    # Next illegal move should end turn    
    obs, reward, done, info = env.step(0)
    assert info['illegal_move'] == False
    assert info['turn'] == 1
Example #17
def test_env(env_id):
    """
    Check that an environment integrated in Gym passes the check.

    :param env_id: (str)
    """
    env = gym.make(env_id)
    with pytest.warns(None) as record:
        check_env(env)

    # Pendulum-v0 will produce a warning because the action space is
    # in [-2, 2] and not [-1, 1]
    if env_id == 'Pendulum-v0':
        assert len(record) == 1
    else:
        # The other environments must pass without warning
        assert len(record) == 0
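A hedged sketch of how such a warning can be avoided by rescaling the action space to [-1, 1]; the RescaleAction wrapper is assumed to be available in the installed gym version:

import gym
from gym.wrappers import RescaleAction
from stable_baselines.common.env_checker import check_env

# Rescale Pendulum's [-2, 2] action space to [-1, 1] so check_env should no
# longer warn about the action range
env = RescaleAction(gym.make('Pendulum-v0'), -1.0, 1.0)
check_env(env)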
Example #18
def test_high_dimension_action_space():
    """
    Test for continuous action space
    with more than one action.
    """
    env = gym.make('Pendulum-v0')
    # Patch the action space
    env.action_space = spaces.Box(low=-1,
                                  high=1,
                                  shape=(20, ),
                                  dtype=np.float32)

    # Patch to avoid error
    def patched_step(_action):
        return env.observation_space.sample(), 0.0, False, {}

    env.step = patched_step
    check_env(env)
def baseline_test():
    from stable_baselines import PPO2
    from stable_baselines.common.policies import MlpPolicy
    from stable_baselines.common import env_checker

    gym_env_fn = lambda: Gym_procgen_continuous(env_name='fruitbot')

    env_checker.check_env(gym_env_fn())
    print('Success! : environment is compatible with stable baselines')

    env = make_vec_env(gym_env_fn, n_envs=64)

    model = PPO2(MlpPolicy, env, verbose=1)
    model.learn(total_timesteps=100000)

    print('Success! : stable baselines trained on the vectorized environment')
Example #20
def main(video=True, title=False):
    env = gym.make("futbol-v1")
    check_env(env, warn=True)
    file_name = 'left-ppo2-lstm-2v2-5e3'
    model = PPO2.load("zoo/2v2/" + file_name)
    # file_name = 'ppo2-futbol-10M_best_model'
    # model = PPO2.load("supplement/" + file_name)
    prefix = file_name
    record_length = 300
    if video:
        if title:
            record_video_with_title('futbol-v1', model, prefix=prefix)
            show_video('videos/' + prefix + '.mp4')
        else:
            record_video('futbol-v1', model, video_length=record_length, prefix=prefix, lstm=True)
            show_video('videos/' + prefix + '-step-0-to-step-' + str(record_length) + '.mp4')

    else:
        record_gif('futbol-v1', model, video_length=record_length, prefix=prefix)
def test_illegal_move_limit_non_driving_pelican():
    def handler(signum, frame):
        raise TimeoutError("Timed out after 30 seconds. This probably means an infinite loop of illegal moves was allowed")

    # Raise an exception if this method gets stuck
    signal.signal(signal.SIGALRM, handler)
    signal.alarm(30)
    
    try:
        # env = plark_env.PlarkEnv(driving_agent='panther',pelican_agent_filepath='/Components/plark-game/plark_game/agents/basic/PelicanAgentIllegalMove.py', pelican_agent_name='PelicanAgentIllegalMove',config_file_path='/Components/plark-game/plark_game/game_config/10x10/balanced_panther_multi_move.json')
        env = gym.make('plark-env-v0', driving_agent='panther',pelican_agent_filepath='/Components/plark-game/plark_game/agents/basic/PelicanAgentIllegalMove.py', pelican_agent_name='PelicanAgentIllegalMove',config_file_path='/Components/plark-game/plark_game/game_config/10x10/balanced_panther_multi_move.json')  
        check_env(env)
        env.reset()
        game = env.env.activeGames[-1]
        pelicanAgent = game.pelicanAgent
        pelicanAgent.reset_moves_taken()

        assert game.pelicanPlayer.row == 0, "Pelican should start at the top of the map"

        assert pelicanAgent.moves_taken == 0, "Pelican should not have taken any moves before start of game"

        for turn in [0, 1]:
            # Make moves until turn is about to end
            for move in range(4):
                # Alternately move up/down
                action = 0 if move % 2 == 0 else 3
                obs, reward, done, info = env.step(action)
                assert info['turn'] == turn, "Should still be on turn {}".format(turn)
                assert game.illegal_pelican_move == True, "Pelican should have made an illegal move"
                assert pelicanAgent.moves_taken == (turn+1)*game.max_illegal_moves_per_turn, "Pelican should have exhausted illegal moves for this turn"
                assert game.pelicanPlayer.row == 0, "Pelican should remain at the top of the map"
            
            # Next move should end turn
            obs, reward, done, info = env.step(0)
            assert info['turn'] == turn + 1, "Turn should have ended"
            assert pelicanAgent.moves_taken == (turn+1)*game.max_illegal_moves_per_turn, "pelican should have exhausted illegal moves this turn"
            assert game.pelicanPlayer.row == 0, "Pelican should remain at the top of the map"
    finally:
        # Cancel timer if test has finished executing (successful or not)
        signal.alarm(0)
def test_env_result():
    panther_agent_filepath = '/data/agents/models/test_20200325_184254/PPO2_20200325_184254_panther/' #This model is in minio or the package.
    # env = plark_env.PlarkEnv(driving_agent='pelican',panther_agent_filepath=panther_agent_filepath,config_file_path='/Components/plark-game/plark_game/game_config/10x10/balanced.json')
    env = gym.make('plark-env-v0', driving_agent='pelican',panther_agent_filepath=panther_agent_filepath)    
    check_env(env)

    training_steps = 500
    model = PPO2('CnnPolicy', env)
    model.learn(training_steps)

    n_eval_episodes = 5
    for _ in range(n_eval_episodes):
        obs = env.reset()
        done, state = False, None
        episode_reward = 0.0
        episode_length = 0
        victory = False
        while not done:
            action, state = model.predict(obs, state=state, deterministic=False)
            obs, reward, done, _info = env.step(action)
            if done:
                assert 'result' in _info, "Info should contain result when game is done"
                assert _info['result'] in ["WIN", "LOSE"], "result should be WIN or LOSE"
Example #23
def main():
    #logging.basicConfig(level=logging.DEBUG)
    # Create the environment
    ENV_NAME = "MineRLTreechop-v0"
    env = gym.make(ENV_NAME)
    # NOTE: the original snippet had a truncated line here ("env.action_space = ");
    # MineRL exposes a Dict action space, which would need to be replaced or
    # flattened before check_env and A2C can handle it.
    check_env(env)

    # Define the model
    model = A2C(CnnPolicy, env, verbose=1)

    # Train the agent
    model.learn(total_timesteps=8000)
Example #24
def test_identity(model_name):
    """
    Test if the algorithm (with a given policy)
    can learn an identity transformation (i.e. return observation as an action)
    :param model_name: (str) Name of the RL model
    """
    env = DummyVecEnv([lambda: IdentityEnv(18, 18, 60)])
    print('before attaching the monitor:', env)
    check_env(env, warn=True)
    print('environment checked')
    # env = Monitor(env, log_dir)
    print('after attaching the monitor:', env)
    # episode_lengths = Monitor.get_episode_lengths()
    # print(episode_lengths)
    print('environment built and monitor attached so far')
    model = LEARN_FUNC_DICT[model_name](env)
    print('training finished; evaluation is next')
    mean_reward, std_reward = evaluate_policy(model,
                                              env,
                                              n_eval_episodes=20,
                                              reward_threshold=None,
                                              return_episode_rewards=False)
    print('evaluation finished')
    obs = env.reset()
    assert model.action_probability(obs).shape == (
        1, 18), "Error: action_probability not returning correct shape"
    action = env.action_space.sample()
    action_prob = model.action_probability(obs, actions=action)
    assert np.prod(action_prob.shape) == 1, "Error: not scalar probability"
    action_logprob = model.action_probability(obs, actions=action, logp=True)
    assert np.allclose(action_prob,
                       np.exp(action_logprob)), (action_prob, action_logprob)

    # Free memory

    del model, env
    return mean_reward, std_reward
Example #25
def main():
    env = gym.make('Skewb-v0')
    check_env(env, warn=True)

    model = DQN(MlpPolicy,
                env,
                verbose=1,
                learning_rate=0.001,
                buffer_size=50000,
                exploration_fraction=0.3,
                exploration_final_eps=0.01,
                train_freq=4,
                learning_starts=10000,
                target_network_update_freq=1000,
                gamma=0.99,
                prioritized_replay=True,
                prioritized_replay_alpha=0.6).learn(total_timesteps=2000000,
                                                    log_interval=800)  #,
    #callback=create_callback_func())

    print(model.learning_rate)

    print("Saving...")
    model.save("model.pkl")
Example #26
    print(f"Text: {env.current_sample.text}")
    print(f"Predicted Label {actions}")
    print(f"Oracle Label: {env.current_sample.label}")
    print(f"Total Reward: {total_reward}")
    print("---------------------------------------------")


# data pool
pool = ReutersDataPool.prepare(split="train")
labels = pool.labels()

# reward function
reward_fn = F1RewardFunction()

# multi label env
env = MultiLabelEnv(possible_labels=labels, max_steps=10, reward_function=reward_fn,
                    return_obs_as_vector=True)
for sample, weight in pool:
    env.add_sample(sample, weight)

# check the environment
check_env(env, warn=True)

# train a MLP Policy
model = DQN(env=env, policy=DQNPolicy, gamma=0.99, batch_size=32, learning_rate=1e-3,
            double_q=True, exploration_fraction=0.1,
            prioritized_replay=False, policy_kwargs={"layers": [200]},
            verbose=1)
for i in range(int(1e+3)):
    model.learn(total_timesteps=int(1e+3), reset_num_timesteps=False)
    eval_model(model, env)
Example #27
 def _init():
     env = gym.make(env_id)
     check_env(env)
     return env
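This closure is the usual factory pattern for building vectorized environments; a hedged sketch of how it is typically consumed (the surrounding make_env wrapper, the env id, and the copy count are illustrative assumptions):

import gym
from stable_baselines.common.env_checker import check_env
from stable_baselines.common.vec_env import DummyVecEnv

def make_env(env_id):
    def _init():
        env = gym.make(env_id)
        check_env(env)
        return env
    return _init

vec_env = DummyVecEnv([make_env('CartPole-v1') for _ in range(4)])  # 4 parallel copies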
Example #28
# NOTE: the class header of this snippet was truncated; an assumed reconstruction
# of the custom recurrent policy it belongs to (referenced as MyLstmPolicy below):
class MyLstmPolicy(LstmPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch,
                 n_lstm=256, reuse=False, **_kwargs):
        super(MyLstmPolicy, self).__init__(sess, ob_space, ac_space,
                                           n_env,
                                           n_steps,
                                           n_batch,
                                           n_lstm,
                                           reuse,
                                           layers=None,
                                           net_arch=net_arch,
                                           layer_norm=False,
                                           feature_extraction="mlp",
                                           **_kwargs)
        global training_sess
        training_sess = sess


env = gym.make('PccNs-v0')
check_env(env)
#AttributeError: 'SimulatedNetworkEnv' object has no attribute 'num_envs'
#print("Number of environments used for training (env.num_envs): " + str(env.num_envs))
#env = gym.make('CartPole-v0')

gamma = arg_or_default("--gamma", default=0.99)
print("gamma = %f" % gamma)
#https://github.com/hill-a/stable-baselines/#implemented-algorithms
#https://stable-baselines.readthedocs.io/en/master/guide/algos.html
#PPO1 can't be used with Recurrent
#In Algo 1 of paper: T is timesteps_per_actorbatch, M = optim_batchsize << NT.
#model = PPO1(MyLstmPolicy, env, verbose=1, schedule='constant', timesteps_per_actorbatch=8192, optim_batchsize=2048, gamma=gamma)

#nminibatches – (int) Number of training minibatches per update. For recurrent policies,
#the number of environments run in parallel should be a multiple of nminibatches.
#https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html?highlight=ppo2
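The comments above note that, for recurrent policies, the number of parallel environments should be a multiple of nminibatches; a hedged sketch of a PPO2 setup that respects this (hyperparameter values are illustrative, MyLstmPolicy and gamma come from the code above):

import gym
from stable_baselines import PPO2
from stable_baselines.common.vec_env import DummyVecEnv

n_envs = 8  # must be a multiple of nminibatches for recurrent policies
vec_env = DummyVecEnv([lambda: gym.make('PccNs-v0') for _ in range(n_envs)])
model = PPO2(MyLstmPolicy, vec_env, verbose=1, gamma=gamma,
             n_steps=128, nminibatches=4)  # 8 envs / 4 minibatches -> OK
model.learn(total_timesteps=100000)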
Example #29
def main():
    """
    Testing gym pacman enviorment.
    """

    agent_name = "GymEnvTestAgent"
    ghosts = 4
    level_ghosts = 1
    lives = 3
    timeout = 3000

    obs_type = MultiChannelObs

    positive_rewards = True

    env = PacmanEnv(obs_type,
                    positive_rewards,
                    agent_name,
                    ghosts,
                    level_ghosts,
                    lives,
                    timeout,
                    training=False)
    env.set_env_params(EnvParams(1, 1, 'data/map2.bmp', 10))
    print("Checking environment...")
    check_env(env, warn=True)

    print("\nObservation space:", env.observation_space)
    print("Shape:", env.observation_space.shape)
    # print("Observation space high:", env.observation_space.high)
    # print("Observation space low:", env.observation_space.low)

    print("Action space:", env.action_space)

    obs = env.reset()
    done = False

    sum_rewards = 0
    action = 1  # a
    cur_x, cur_y = None, None

    while not done:
        env.render()

        x, y = env._game._pacman

        # Using agent from client example
        if x == cur_x and y == cur_y:
            if action in [1, 3]:  # ad
                action = random.choice([0, 2])
            elif action in [0, 2]:  # ws
                action = random.choice([1, 3])
        cur_x, cur_y = x, y

        print("key:", PacmanEnv.keys[action])

        obs, reward, done, info = env.step(action)

        sum_rewards += reward

        print("reward:", reward)
        print("sum_rewards:", sum_rewards)
        print("info:", info)
        print()
Example #30
def main() -> None:
    """Run OKWSB."""
    print("--- OKWSB ---")
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--starting_capital',
                        type=int,
                        default=100000,
                        help="How much capital is the training starting with")
    parser.add_argument('--training_timesteps',
                        type=int,
                        default=1000000,
                        help="How many timesteps are in the training")
    parser.add_argument("--model_name",
                        default="okwsb-pp02",
                        help="The name of the model file to write")
    parser.add_argument('--mode',
                        choices=(MODE_DATA, MODE_TRAIN, MODE_TEST, MODE_LIVE),
                        required=True,
                        help="The mode to run OKWSB in")
    parser.add_argument("--alphavantage_key",
                        required=True,
                        help="The API key for interfacing to AlphaVantage")
    parser.add_argument("--data_folder",
                        default=TRAINING_DATA_FOLDER,
                        help="The folder to store the training data in")
    parser.add_argument("--data_stock_tickers_max",
                        default=16,
                        help="The maximum amount of stock tickers")
    parser.add_argument("--data_stock_tickers",
                        required=False,
                        nargs="+",
                        help="The stock tickers to download")
    args = parser.parse_args()
    # Consolidate data
    timed_data = TimedDataLoader(args.alphavantage_key,
                                 args.data_folder,
                                 stock_tickers_max=args.data_stock_tickers_max,
                                 stock_tickers=args.data_stock_tickers)
    if not timed_data.has_data() and args.mode != MODE_DATA:
        print(
            "No data found. Try `okwsb --mode=data` to collect a local dataset."
        )
        sys.exit(1)
    if args.mode == MODE_DATA:
        timed_data.extract()
    else:
        # Validate environment
        gym.envs.registration.register(id=ENVIRONMENT_ID,
                                       entry_point='okwsb:StockEnv')
        env = gym.make(ENVIRONMENT_ID,
                       capital=args.starting_capital,
                       timed_data=timed_data,
                       playback=args.mode == MODE_TEST)
        check_env(env)
        env.reset()
        if args.mode == MODE_TRAIN:
            model = PPO2(MlpPolicy, env, verbose=1)
            model.learn(total_timesteps=args.training_timesteps)
            model.save(args.model_name)
            env.reset()
        elif args.mode == MODE_TEST:
            if not timed_data.has_data():
                timed_data.extract()
            model = PPO2.load(args.model_name)
            obs = env.reset()
            try:
                while True:
                    action, _ = model.predict(obs)
                    obs, _, done, _ = env.step(action)
                    if done:
                        env.render()
                        obs = env.reset()
            except StopIteration:
                pass
        elif args.mode == MODE_LIVE:
            pass
        env.close()