def make_env_with_constraints(env_id, novelty_family=None, inject=False):
    env2 = gym.make(env_id)
    env2.unbreakable_items.add('crafting_table')  # make crafting table unbreakable for easier solving of the task
    env2 = LimitActions(env2, {'Forward', 'Left', 'Right', 'Break', 'Craft_bow'})  # limit actions for easier training
    env2 = LidarInFront(env2)  # generate the observation space using LIDAR sensors
    env2.reward_done = 1000
    env2.reward_intermediate = 50
    if inject:
        env2 = inject_novelty(env2, novelty_family[0], novelty_family[1],
                              novelty_family[2], novelty_family[3])
        if novelty_family[0] == 'breakincrease':
            print("Break increase novelty injected::: self.env.itemtobreakmore = {}".format(env2.itemtobreakmore))
        if novelty_family[0] == 'remapaction':
            print("Action remap novelty check:: self.env.limited_actions_id = {}".format(env2.limited_actions_id))
    check_env(env2, warn=True)  # check the environment
    return env2
def main():
    dm, y_oracle = init_dm(CONFIG)
    print(dm)
    cnn = ConvNet(n_output=2)
    env = QueryEnv(dm, cnn, CONFIG)
    env = MonitorWrapper(env, autolog=True)
    check_env(env, warn=True)
    model = load_dqn_model('data/rl_rps.pth')
    n_queries = 10
    n_epochs_cnn = 8
    for k in range(n_queries):
        print(dm)
        done = False
        obs = env.reset()
        while not done:
            action, _states = model.predict(obs)
            obs, reward, done, info = env.step(action)
        query_indicies = env.get_query_indicies()
        dm.label_samples(query_indicies, y_oracle[query_indicies])
        y_oracle = np.delete(y_oracle, query_indicies, axis=0)
        cnn.fit(*dm.train.get_xy(), n_epochs_cnn, batch_size=32)
        cnn.evaluate(*dm.eval.get_xy())
        cnn.evaluate(*dm.test.get_xy())
def test_env_init(file_array, stock_number, action_space):
    from stable_baselines.common.env_checker import check_env
    env = StockTradingEnvV1(df=None, file_array=file_array)
    check_env(env)
    assert env.action_space == spaces.MultiDiscrete(action_space)
    assert env.stock_number == stock_number
    assert env.reset().shape == env.observation_space.shape
    assert env.obs_column_number == 1 * stock_number
def main(): """ Tests gym-compatibility of envs """ env = PokeEnv() check_env(env) print('Test complete.') return True
def test_stable_baselines_check(self):
    '''Use the environment checker from Stable Baselines to test the environment.

    This checks that the environment follows the Gym API. It also optionally
    checks that the environment is compatible with the Stable Baselines repository.
    '''
    check_env(self.env, warn=True)
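# A minimal sketch (not from the original source) of the kind of environment
# that passes check_env: the checker asserts that observation_space and
# action_space are declared and that reset()/step() return values that live
# in those spaces. MinimalEnv is a hypothetical name used only for illustration.
import gym
import numpy as np
from gym import spaces
from stable_baselines.common.env_checker import check_env

class MinimalEnv(gym.Env):
    """Hypothetical single-step environment illustrating the Gym API."""

    def __init__(self):
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
        self.action_space = spaces.Discrete(2)

    def reset(self):
        # Must return an observation inside observation_space
        return np.zeros(3, dtype=np.float32)

    def step(self, action):
        # Must return the 4-tuple (obs, reward, done, info)
        obs = self.observation_space.sample()
        reward = 1.0 if action == 1 else 0.0
        done = True
        return obs, reward, done, {}

check_env(MinimalEnv(), warn=True)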
def test_id_env():
    env = IdEnv(max_steps=6)
    check_env(env)
    assert env.reset()[0] == 0
    collector = EnvDataCollector(env)
    episode(collector)
    collector.flush()
    assert len(collector.raw_data[0]) == 7
def check_environment(env_name):
    from gym_mimic_envs.monitor import Monitor as EnvMonitor
    from stable_baselines.common.env_checker import check_env
    env = gym.make(env_name)
    log('Checking custom environment')
    check_env(env)
    env = EnvMonitor(env)
    log('Checking custom env in custom monitor wrapper')
    check_env(env)
    exit(33)
def test_env_and_pre_trained_agent():
    panther_agent_filepath = '/data/agents/models/test_20200325_184254/PPO2_20200325_184254_panther/'  # this model is in minio or the package
    env = gym.make('plark-env-v0', driving_agent='pelican',
                   panther_agent_filepath=panther_agent_filepath)
    check_env(env)
    training_steps = 500
    model = PPO2('CnnPolicy', env)
    model.learn(training_steps)
    n_eval_episodes = 2
    mean_reward, n_steps, victories = helper.evaluate_policy(
        model, env, n_eval_episodes=n_eval_episodes, deterministic=False,
        render=False, callback=None, reward_threshold=None,
        return_episode_rewards=False)
def test_non_default_spaces(new_obs_space):
    env = gym.make('BreakoutNoFrameskip-v4')
    env.observation_space = new_obs_space
    # Patch methods to avoid errors
    env.reset = new_obs_space.sample

    def patched_step(_action):
        return new_obs_space.sample(), 0.0, False, {}

    env.step = patched_step
    with pytest.warns(UserWarning):
        check_env(env)
def __init__(self, model: abc.ABCMeta):
    # Define new environment
    self.env = RaspEnv()
    # Check that the environment follows the Gym API
    check_env(self.env)
    # Define empty model
    self.model = model
    # Configure model hyperparameters
    self.config = {
        'learning_starts': 32,
        'target_network_update_freq': 100,
        'learning_rate': 0.001
    }
def check_reset_assert_error(env, new_reset_return):
    """
    Helper to check that the error is caught.

    :param env: (gym.Env)
    :param new_reset_return: (Any)
    """
    def wrong_reset():
        return new_reset_return

    # Patch the reset method with a wrong one
    env.reset = wrong_reset
    with pytest.raises(AssertionError):
        check_env(env)
def check_step_assert_error(env, new_step_return=()):
    """
    Helper to check that the error is caught.

    :param env: (gym.Env)
    :param new_step_return: (tuple)
    """
    def wrong_step(_action):
        return new_step_return

    # Patch the step method with a wrong one
    env.step = wrong_step
    with pytest.raises(AssertionError):
        check_env(env)
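# A hedged usage sketch (not from the original source) showing how the two
# helpers above are typically driven; 'CartPole-v0' stands in for whatever
# environment the surrounding test suite uses.
env = gym.make('CartPole-v0')
# reset() must return a single observation inside observation_space,
# so a tuple is rejected by check_env
check_reset_assert_error(env, (env.observation_space.sample(), "extra"))

env = gym.make('CartPole-v0')
# step() must return a 4-tuple (obs, reward, done, info)
check_step_assert_error(env, (env.observation_space.sample(), 0.0))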
def test_env_and_agent():
    env = gym.make('plark-env-v0')
    check_env(env)
    training_steps = 10
    model = PPO2('CnnPolicy', env)
    model.learn(training_steps)
    n_eval_episodes = 2
    mean_reward, n_steps, victories = helper.evaluate_policy(
        model, env, n_eval_episodes=n_eval_episodes, deterministic=False,
        render=False, callback=None, reward_threshold=None,
        return_episode_rewards=False)
    assert "PlarkEnv" in str(type(env))
    assert "PPO2" in str(type(model))
def test_normalize_false():
    from stable_baselines.common.env_checker import check_env
    env = StockTradingEnvV1(debug=True, normalize_observation=False)
    check_env(env)
    assert env.observation_space.shape == (14,)
    obs = env.reset()
    stacked_obs_part = np.reshape(
        obs[:env.obs_column_number * env.observation_frame],
        (-1, env.stock_number))
    index_end = env.current_step + env.start_point
    index_start = env.current_step + env.start_point - env.observation_frame + 1
    ref_obs = env.df[env.observation_column_name_array].loc[index_start:index_end].values
    assert ((ref_obs == stacked_obs_part).all()).all()
def test_illegal_move_limit_driving_pelican():
    panther_agent_filepath = '/Components/plark-game/plark_game/agents/basic/pantherAgent_move_north.py'  # this agent is in minio or the package
    env = gym.make('plark-env-v0', driving_agent='pelican',
                   panther_agent_filepath=panther_agent_filepath,
                   panther_agent_name='Panther_Agent_Move_North')
    check_env(env)
    env.reset()
    # Repeat the illegal move (illegal move limit - 1) times
    for i in range(9):
        obs, reward, done, info = env.step(0)
        assert info['illegal_move'] == True, "Should have made illegal move"
        assert info['turn'] == 0, "Should still be on first turn"
    # The next illegal move should end the turn
    obs, reward, done, info = env.step(0)
    assert info['illegal_move'] == True
    assert info['turn'] == 1
def test_illegal_move_limit_driving_panther():
    pelican_agent_filepath = '/data/agents/models/test_20200325_184254/PPO2_20200325_184254_pelican/'  # this model is in minio or the package
    env = gym.make('plark-env-v0', driving_agent='panther',
                   pelican_agent_filepath=pelican_agent_filepath)
    check_env(env)
    env.reset()
    # Repeat the illegal move (illegal move limit - 1) times
    for i in range(9):
        obs, reward, done, info = env.step(3)
        assert info['illegal_move'] == True, "Should have made illegal move"
        assert info['turn'] == 0, "Should still be on first turn"
    # A legal move should now end the turn
    obs, reward, done, info = env.step(0)
    assert info['illegal_move'] == False
    assert info['turn'] == 1
def test_env(env_id):
    """
    Check that environments integrated in Gym pass the test.

    :param env_id: (str)
    """
    env = gym.make(env_id)
    with pytest.warns(None) as record:
        check_env(env)
    # Pendulum-v0 will produce a warning because the action space is
    # in [-2, 2] and not [-1, 1]
    if env_id == 'Pendulum-v0':
        assert len(record) == 1
    else:
        # The other environments must pass without warning
        assert len(record) == 0
def test_high_dimension_action_space():
    """
    Test a continuous action space with more than one action.
    """
    env = gym.make('Pendulum-v0')
    # Patch the action space
    env.action_space = spaces.Box(low=-1, high=1, shape=(20,), dtype=np.float32)

    # Patch step to avoid errors
    def patched_step(_action):
        return env.observation_space.sample(), 0.0, False, {}

    env.step = patched_step
    check_env(env)
def baseline_test():
    from stable_baselines import PPO2
    from stable_baselines.common.policies import MlpPolicy
    from stable_baselines.common import env_checker

    gym_env_fn = lambda: Gym_procgen_continuous(env_name='fruitbot')
    env_checker.check_env(gym_env_fn())
    print('Success!: environment is compatible with stable baselines')

    env = make_vec_env(gym_env_fn, n_envs=64)
    model = PPO2(MlpPolicy, env, verbose=1)
    model.learn(total_timesteps=100000)
    print('Success!: stable baselines trained on the vectorized environment')
def main(video=True, title=False):
    env = gym.make("futbol-v1")
    check_env(env, warn=True)

    file_name = 'left-ppo2-lstm-2v2-5e3'
    model = PPO2.load("zoo/2v2/" + file_name)
    # file_name = 'ppo2-futbol-10M_best_model'
    # model = PPO2.load("supplement/" + file_name)

    prefix = file_name
    record_length = 300
    if video:
        if title:
            record_video_with_title('futbol-v1', model, prefix=prefix)
            show_video('videos/' + prefix + '.mp4')
        else:
            record_video('futbol-v1', model, video_length=record_length,
                         prefix=prefix, lstm=True)
            show_video('videos/' + prefix + '-step-0-to-step-' + str(record_length) + '.mp4')
    else:
        record_gif('futbol-v1', model, video_length=record_length, prefix=prefix)
def test_illegal_move_limit_non_driving_pelican():
    def handler(signum, frame):
        raise TimeoutError("Timed out after 30 seconds. This probably means "
                           "an infinite loop of illegal moves was allowed")

    # Raise an exception if this method gets stuck
    signal.signal(signal.SIGALRM, handler)
    signal.alarm(30)
    try:
        env = gym.make('plark-env-v0', driving_agent='panther',
                       pelican_agent_filepath='/Components/plark-game/plark_game/agents/basic/PelicanAgentIllegalMove.py',
                       pelican_agent_name='PelicanAgentIllegalMove',
                       config_file_path='/Components/plark-game/plark_game/game_config/10x10/balanced_panther_multi_move.json')
        check_env(env)
        env.reset()
        game = env.env.activeGames[-1]
        pelicanAgent = game.pelicanAgent
        pelicanAgent.reset_moves_taken()
        assert game.pelicanPlayer.row == 0, "Pelican should start at the top of the map"
        assert pelicanAgent.moves_taken == 0, "Pelican should not have taken any moves before start of game"
        for turn in [0, 1]:
            # Make moves until the turn is about to end
            for move in range(4):
                # Alternately move up/down
                action = 0 if move % 2 == 0 else 3
                obs, reward, done, info = env.step(action)
                assert info['turn'] == turn, "Should still be on turn {}".format(turn)
            assert game.illegal_pelican_move == True, "Pelican should have made an illegal move"
            assert pelicanAgent.moves_taken == (turn + 1) * game.max_illegal_moves_per_turn, \
                "Pelican should have exhausted illegal moves for this turn"
            assert game.pelicanPlayer.row == 0, "Pelican should remain at the top of the map"
            # The next move should end the turn
            obs, reward, done, info = env.step(0)
            assert info['turn'] == turn + 1, "Turn should have ended"
            assert pelicanAgent.moves_taken == (turn + 1) * game.max_illegal_moves_per_turn, \
                "Pelican should have exhausted illegal moves this turn"
            assert game.pelicanPlayer.row == 0, "Pelican should remain at the top of the map"
    finally:
        # Cancel the timer once the test has finished executing (successful or not)
        signal.alarm(0)
def test_env_result():
    panther_agent_filepath = '/data/agents/models/test_20200325_184254/PPO2_20200325_184254_panther/'  # this model is in minio or the package
    env = gym.make('plark-env-v0', driving_agent='pelican',
                   panther_agent_filepath=panther_agent_filepath)
    check_env(env)
    training_steps = 500
    model = PPO2('CnnPolicy', env)
    model.learn(training_steps)
    n_eval_episodes = 5
    for _ in range(n_eval_episodes):
        obs = env.reset()
        done, state = False, None
        episode_reward = 0.0
        episode_length = 0
        victory = False
        while not done:
            action, state = model.predict(obs, state=state, deterministic=False)
            obs, reward, done, _info = env.step(action)
            if done:
                assert 'result' in _info, "Info should contain result when game is done"
                assert _info['result'] in ["WIN", "LOSE"], "result should be WIN or LOSE"
def main():
    # logging.basicConfig(level=logging.DEBUG)
    # Create the environment
    ENV_NAME = "MineRLTreechop-v0"
    env = gym.make(ENV_NAME)
    # check_env returns None, so its result must not be assigned to the action space
    check_env(env)
    # Define the model
    model = A2C(CnnPolicy, env, verbose=1)
    # Train the agent
    model.learn(total_timesteps=8000)
def test_identity(model_name):
    """
    Test if the algorithm (with a given policy) can learn an identity
    transformation (i.e. return the observation as the action).

    :param model_name: (str) Name of the RL model
    """
    # check_env expects a plain gym.Env, so check the raw env before vectorizing
    check_env(IdentityEnv(18, 18, 60), warn=True)
    print('Environment checked')
    env = DummyVecEnv([lambda: IdentityEnv(18, 18, 60)])
    print('Before wrapping with the monitor:', env)
    # env = Monitor(env, log_dir)
    print('After wrapping with the monitor:', env)
    print('So far the environment is built and the monitor installed')
    model = LEARN_FUNC_DICT[model_name](env)  # a hedged sketch of LEARN_FUNC_DICT follows this function
    print('Training is finished; evaluation comes next')
    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20,
                                              reward_threshold=None,
                                              return_episode_rewards=False)
    print('Evaluation is finished')
    obs = env.reset()
    assert model.action_probability(obs).shape == (1, 18), \
        "Error: action_probability not returning correct shape"
    action = env.action_space.sample()
    action_prob = model.action_probability(obs, actions=action)
    assert np.prod(action_prob.shape) == 1, "Error: not scalar probability"
    action_logprob = model.action_probability(obs, actions=action, logp=True)
    assert np.allclose(action_prob, np.exp(action_logprob)), (action_prob, action_logprob)
    # Free memory
    del model, env
    return mean_reward, std_reward
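# LEARN_FUNC_DICT is used above but not defined in this snippet. A minimal
# sketch of its likely shape (an assumption, mirroring the stable-baselines
# identity tests): each entry maps a model name to a function that builds and
# trains that model on the given env.
from stable_baselines import A2C, PPO2

LEARN_FUNC_DICT = {
    'a2c': lambda env: A2C(policy="MlpPolicy", env=env).learn(total_timesteps=1000),
    'ppo2': lambda env: PPO2(policy="MlpPolicy", env=env).learn(total_timesteps=1000),
}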
def main():
    env = gym.make('Skewb-v0')
    check_env(env, warn=True)
    model = DQN(MlpPolicy, env, verbose=1, learning_rate=0.001,
                buffer_size=50000, exploration_fraction=0.3,
                exploration_final_eps=0.01, train_freq=4,
                learning_starts=10000, target_network_update_freq=1000,
                gamma=0.99, prioritized_replay=True,
                prioritized_replay_alpha=0.6).learn(total_timesteps=2000000,
                                                    log_interval=800)
    # callback=create_callback_func()
    print(model.learning_rate)
    print("Saving...")
    model.save("model.pkl")
print(f"Text: {env.current_sample.text}") print(f"Predicted Label {actions}") print(f"Oracle Label: {env.current_sample.label}") print(f"Total Reward: {total_reward}") print("---------------------------------------------") # data pool pool = ReutersDataPool.prepare(split="train") labels = pool.labels() # reward function reward_fn = F1RewardFunction() # multi label env env = MultiLabelEnv(possible_labels=labels, max_steps=10, reward_function=reward_fn, return_obs_as_vector=True) for sample, weight in pool: env.add_sample(sample, weight) # check the environment check_env(env, warn=True) # train a MLP Policy model = DQN(env=env, policy=DQNPolicy, gamma=0.99, batch_size=32, learning_rate=1e-3, double_q=True, exploration_fraction=0.1, prioritized_replay=False, policy_kwargs={"layers": [200]}, verbose=1) for i in range(int(1e+3)): model.learn(total_timesteps=int(1e+3), reset_num_timesteps=False) eval_model(model, env)
def _init():
    env = gym.make(env_id)
    check_env(env)
    return env
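# A hedged usage sketch (not from the original source): _init presumably
# closes over env_id inside an outer make_env(env_id) factory, and the
# resulting thunks are what DummyVecEnv/SubprocVecEnv expect.
from stable_baselines.common.vec_env import DummyVecEnv

def make_env(env_id):
    def _init():
        env = gym.make(env_id)
        check_env(env)
        return env
    return _init

# Each worker of the vectorized env builds (and checks) its own copy
vec_env = DummyVecEnv([make_env('CartPole-v0') for _ in range(4)])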
# Tail of a custom LSTM policy constructor: these are the trailing arguments
# of a super().__init__(...) call whose opening is not part of this snippet.
                         n_env, n_steps, n_batch, n_lstm, reuse, layers=None,
                         net_arch=net_arch, layer_norm=False,
                         feature_extraction="mlp", **_kwargs)
        global training_sess
        training_sess = sess

env = gym.make('PccNs-v0')
check_env(env)
# AttributeError: 'SimulatedNetworkEnv' object has no attribute 'num_envs'
# print("Number of environments used for training (env.num_envs): " + str(env.num_envs))
# env = gym.make('CartPole-v0')

gamma = arg_or_default("--gamma", default=0.99)
print("gamma = %f" % gamma)

# https://github.com/hill-a/stable-baselines/#implemented-algorithms
# https://stable-baselines.readthedocs.io/en/master/guide/algos.html
# PPO1 can't be used with recurrent policies.
# In Algo 1 of the paper: T is timesteps_per_actorbatch, M = optim_batchsize << NT.
# model = PPO1(MyLstmPolicy, env, verbose=1, schedule='constant',
#              timesteps_per_actorbatch=8192, optim_batchsize=2048, gamma=gamma)
# nminibatches – (int) Number of training minibatches per update. For recurrent
# policies, the number of environments run in parallel should be a multiple
# of nminibatches.
# https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html?highlight=ppo2
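# arg_or_default is called above but not defined in this snippet. A hedged
# sketch of what it plausibly does (an assumption):
import sys

def arg_or_default(flag, default):
    # Return the value following `flag` on the command line, cast to the
    # type of `default`; fall back to `default` when the flag is absent.
    if flag in sys.argv:
        return type(default)(sys.argv[sys.argv.index(flag) + 1])
    return default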
def main(): """ Testing gym pacman enviorment. """ agent_name = "GymEnvTestAgent" ghosts = 4 level_ghosts = 1 lives = 3 timeout = 3000 obs_type = MultiChannelObs positive_rewards = True env = PacmanEnv(obs_type, positive_rewards, agent_name, ghosts, level_ghosts, lives, timeout, training=False) env.set_env_params(EnvParams(1, 1, 'data/map2.bmp', 10)) print("Checking environment...") check_env(env, warn=True) print("\nObservation space:", env.observation_space) print("Shape:", env.observation_space.shape) # print("Observation space high:", env.observation_space.high) # print("Observation space low:", env.observation_space.low) print("Action space:", env.action_space) obs = env.reset() done = False sum_rewards = 0 action = 1 # a cur_x, cur_y = None, None while not done: env.render() x, y = env._game._pacman # Using agent from client example if x == cur_x and y == cur_y: if action in [1, 3]: # ad action = random.choice([0, 2]) elif action in [0, 2]: # ws action = random.choice([1, 3]) cur_x, cur_y = x, y print("key:", PacmanEnv.keys[action]) obs, reward, done, info = env.step(action) sum_rewards += reward print("reward:", reward) print("sum_rewards:", sum_rewards) print("info:", info) print()
def main() -> None:
    """Run OKWSB."""
    print("--- OKWSB ---")
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--starting_capital', type=int, default=100000,
                        help="How much capital the training starts with")
    parser.add_argument('--training_timesteps', type=int, default=1000000,
                        help="How many timesteps are in the training")
    parser.add_argument("--model_name", default="okwsb-pp02",
                        help="The name of the model file to write")
    parser.add_argument('--mode', choices=(MODE_DATA, MODE_TRAIN, MODE_TEST, MODE_LIVE),
                        required=True, help="The mode to run OKWSB in")
    parser.add_argument("--alphavantage_key", required=True,
                        help="The API key for interfacing with AlphaVantage")
    parser.add_argument("--data_folder", default=TRAINING_DATA_FOLDER,
                        help="The folder to store the training data in")
    parser.add_argument("--data_stock_tickers_max", default=16,
                        help="The maximum number of stock tickers")
    parser.add_argument("--data_stock_tickers", required=False, nargs="+",
                        help="The stock tickers to download")
    args = parser.parse_args()

    # Consolidate data
    timed_data = TimedDataLoader(args.alphavantage_key, args.data_folder,
                                 stock_tickers_max=args.data_stock_tickers_max,
                                 stock_tickers=args.data_stock_tickers)
    if not timed_data.has_data() and args.mode != MODE_DATA:
        print("No data found. Try `okwsb --mode=data` to collect a local dataset.")
        sys.exit(1)

    if args.mode == MODE_DATA:
        timed_data.extract()
    else:
        # Validate the environment
        gym.envs.registration.register(id=ENVIRONMENT_ID, entry_point='okwsb:StockEnv')
        env = gym.make(ENVIRONMENT_ID, capital=args.starting_capital,
                       timed_data=timed_data, playback=args.mode == MODE_TEST)
        check_env(env)
        env.reset()
        if args.mode == MODE_TRAIN:
            model = PPO2(MlpPolicy, env, verbose=1)
            model.learn(total_timesteps=args.training_timesteps)
            model.save(args.model_name)
            env.reset()
        elif args.mode == MODE_TEST:
            if not timed_data.has_data():
                timed_data.extract()
            model = PPO2.load(args.model_name)
            obs = env.reset()
            try:
                while True:
                    action, _ = model.predict(obs)
                    obs, _, done, _ = env.step(action)
                    if done:
                        env.render()
                        obs = env.reset()
            except StopIteration:
                pass
        elif args.mode == MODE_LIVE:
            pass
        env.close()