def main():
    """Train (or evaluate) a PPO agent on the MultiSnake environment.

    Reads module-level configuration globals: ``test_model``, ``save_gif``,
    ``num_agents``, ``buf_size``, ``nminibatches``, ``lam``, ``gamma`` and
    the ``on_update`` callback.  Runs until interrupted
    (``total_timesteps=int(1e20)``).
    """
    # Use the session as a context manager so it becomes the default session
    # AND is closed on exit.  (The original called tf.Session().__enter__()
    # and never closed the session — a resource leak.)
    with tf.Session():
        # Seed numpy and TensorFlow for reproducible runs.
        seed = 1
        np.random.seed(seed)
        tf.set_random_seed(seed)

        # Disable periodic logging while evaluating a trained model.
        if test_model:
            log_interval = int(1e20)
        else:
            log_interval = 1

        # env = gym.make('CartPole-v0')

        # Create the Multi-Snake environment.
        spacing = 22
        grid_dim = 10
        history = 4  # NOTE(review): original passed the literal 4; now used consistently.
        env = MultiSnake(num_agents=num_agents,
                         num_fruits=3,
                         spacing=spacing,
                         grid_dim=grid_dim,
                         flatten_states=False,
                         reward_killed=-1.0,
                         history=history,
                         save_gif=save_gif)
        env.reset()
        # Wrap for the baselines-style learner; visualize only when testing.
        env = makegymwrapper(env, visualize=test_model)

        ppo2_actprior.learn(policy=MultiSnakeCNNPolicy,
                            env=env,
                            nsteps=buf_size,
                            nminibatches=nminibatches,
                            lam=lam,
                            gamma=gamma,
                            noptepochs=4,
                            log_interval=log_interval,
                            ent_coef=.01,
                            lr=1e-3,
                            cliprange=.2,
                            total_timesteps=int(1e20),
                            callback_fn=on_update)
giffn = trainhistdir + 'video.gif' # handle command line arguments test_model, save_gif = handle_args(test_model, save_gif) if save_gif: num_episodes = 3 # Create custom OpenAIgym environment num_agents = 1 spacing = 22 grid_dim = 10 e = MultiSnake(num_agents=num_agents, num_fruits=3, spacing=spacing, grid_dim=grid_dim, flatten_states=False, reward_killed=-1.0, save_gif=save_gif) env = OpenAIGym_custom(e, "MultiSnake", visualize=test_model) network_spec = [ dict(type='conv2d', size=16, window=3, stride=1, bias=True), dict(type='conv2d', size=32, window=3, stride=1, bias=True), dict(type='flatten'), dict(type='dense', size=256, bias=True) ] states_preprocessing = [ # dict(type='divide',scale=2)
episode_offset = 0 max_episode_timesteps = 1000 giffn = trainhistdir + 'video.gif' # handle command line arguments test_model, save_gif = handle_args(test_model, save_gif) if save_gif: num_episodes = 3 # Create custom OpenAIgym environment num_agents = 1 spacing = 22 grid_dim = 10 e = MultiSnake(num_agents=num_agents, num_fruits=3, spacing=spacing, grid_dim=grid_dim, save_gif=save_gif) env = OpenAIGym_custom(e, "MultiSnake", visualize=test_model) # Network as list of layers network_spec = [ dict(type='dense', size=128, activation='tanh'), dict(type='dense', size=128, activation='tanh') ] agent = PPOAgent( states=env.states, actions=env.actions, network=network_spec, update_mode = dict( unit='episodes',