# Make sure the output directories exist before training starts.
# tensorboard_folder and model_folder are expected to be assigned earlier in
# the script. exist_ok=True replaces the isdir()-then-makedirs() pair, so a
# directory that appears between the check and the creation no longer raises.
os.makedirs(tensorboard_folder, exist_ok=True)
os.makedirs(model_folder, exist_ok=True)

# Optional first command-line argument: handed to get_policy() as-is and,
# prefixed with '_', appended to the TensorBoard run name and checkpoint name.
policy = ''
model_tag = ''
if len(sys.argv) > 1:
    policy = sys.argv[1]
    model_tag = '_' + sys.argv[1]

if __name__ == '__main__':
    # Four BaseEnv factories are handed to SubprocVecEnv; the result is
    # wrapped in VecFrameStack with 3 as its second argument.
    env_factories = [lambda: BaseEnv() for _ in range(4)]
    env = VecFrameStack(SubprocVecEnv(env_factories), 3)

    # Train PPO2 with the policy selected on the command line.
    run_name = 'PPO2' + model_tag
    model = PPO2(
        get_policy(policy),
        env,
        verbose=0,
        nminibatches=1,
        tensorboard_log=tensorboard_folder,
    )
    model.learn(total_timesteps=10 ** 8, tb_log_name=run_name)

    # Save the trained model, drop the in-memory instance and load it back
    # from the file that was just written.
    checkpoint_path = model_folder + "PPO2" + model_tag
    model.save(checkpoint_path)
    del model
    model = PPO2.load(checkpoint_path)

    # Initial state for the rollout loop that follows.
    obs = env.reset()
    states = None
    done = False

    while not done:
# Output locations for this experiment: TensorBoard logs and saved models.
tensorboard_folder = './tensorboard/Snake/action_mask/'
model_folder = './models/Snake/action_mask/'

# exist_ok=True replaces the isdir()-then-makedirs() pair, so a directory
# that appears between the check and the creation no longer raises.
os.makedirs(tensorboard_folder, exist_ok=True)
os.makedirs(model_folder, exist_ok=True)

# Optional first command-line argument: handed to get_policy() as-is and,
# prefixed with '_', appended to the TensorBoard run name and checkpoint name.
policy = ''
model_tag = ''
if len(sys.argv) > 1:
    policy = sys.argv[1]
    model_tag = '_' + sys.argv[1]

# Single ActionMaskEnv(10, 10) factory handed to DummyVecEnv.
env_factories = [lambda: ActionMaskEnv(10, 10)]
env = DummyVecEnv(env_factories)

# Train PPO2 with the policy selected on the command line.
run_name = 'PPO2' + model_tag
model = PPO2(
    get_policy(policy),
    env,
    verbose=0,
    nminibatches=1,
    tensorboard_log=tensorboard_folder,
)
model.learn(total_timesteps=10 ** 7, tb_log_name=run_name)

# Save the trained model, drop the in-memory instance and load it back from
# the file that was just written.
checkpoint_path = model_folder + "PPO2" + model_tag
model.save(checkpoint_path)
del model
model = PPO2.load(checkpoint_path)

# Roll the reloaded model forward, rendering every step, until the
# environment reports done.
obs = env.reset()
states = None
action_masks = []
done = False

while not done:
    # NOTE(review): action_masks is never refilled and infos is unused, so
    # every predict() call receives an empty mask list - confirm whether the
    # masks were meant to be read back from infos.
    action, states = model.predict(obs, states, action_mask=action_masks)
    obs, _, done, infos = env.step(action)
    env.render()
# Saved models for this experiment go here. tensorboard_folder is expected
# to be assigned earlier in the script.
model_folder = './models/Bomberman/base/'

# exist_ok=True replaces the isdir()-then-makedirs() pair, so a directory
# that appears between the check and the creation no longer raises.
os.makedirs(tensorboard_folder, exist_ok=True)
os.makedirs(model_folder, exist_ok=True)

# Optional first command-line argument: handed to get_policy() as-is and,
# prefixed with '_', appended to the TensorBoard run name and checkpoint name.
policy = ''
model_tag = ''
if len(sys.argv) > 1:
    policy = sys.argv[1]
    model_tag = '_' + sys.argv[1]

# Single BaseEnv factory handed to DummyVecEnv; the result is wrapped in
# VecFrameStack with 3 as its second argument.
env_factories = [lambda: BaseEnv()]
env = VecFrameStack(DummyVecEnv(env_factories), 3)

# Train A2C with the policy selected on the command line.
run_name = 'A2C' + model_tag
model = A2C(get_policy(policy), env, verbose=0, tensorboard_log=tensorboard_folder)
model.learn(total_timesteps=25 * 10 ** 5, tb_log_name=run_name)

# Save the trained model, drop the in-memory instance and load it back from
# the file that was just written.
checkpoint_path = model_folder + "A2C" + model_tag
model.save(checkpoint_path)
del model
model = A2C.load(checkpoint_path)

# Initial state for the rollout loop that follows.
obs = env.reset()
states = None
done = False

while not done:
    action, states = model.predict(obs, states)
# Output locations for this experiment: TensorBoard logs and saved models.
tensorboard_folder = './tensorboard/Snake/base/'
model_folder = './models/Snake/base/'

# exist_ok=True replaces the isdir()-then-makedirs() pair, so a directory
# that appears between the check and the creation no longer raises.
os.makedirs(tensorboard_folder, exist_ok=True)
os.makedirs(model_folder, exist_ok=True)

# Optional first command-line argument: handed to get_policy() as-is and,
# prefixed with '_', appended to the TensorBoard run name and checkpoint name.
policy = ''
model_tag = ''
if len(sys.argv) > 1:
    policy = sys.argv[1]
    model_tag = '_' + sys.argv[1]

# Single BaseEnv(10, 10) factory handed to DummyVecEnv.
env_factories = [lambda: BaseEnv(10, 10)]
env = DummyVecEnv(env_factories)

# Train ACKTR with the policy selected on the command line.
run_name = 'ACKTR_A2C' + model_tag
model = ACKTR(get_policy(policy), env, verbose=0, tensorboard_log=tensorboard_folder)
model.learn(total_timesteps=10 ** 7, tb_log_name=run_name)

# Save the trained model, drop the in-memory instance and load it back from
# the file that was just written.
checkpoint_path = model_folder + "ACKTR_A2C" + model_tag
model.save(checkpoint_path)
del model
model = ACKTR.load(checkpoint_path)

# Initial state for the rollout loop that follows.
obs = env.reset()
states = None
done = False

while not done:
    action, states = model.predict(obs, states)
# Make sure the output directories exist before training starts.
# tensorboard_folder and model_folder are expected to be assigned earlier in
# the script. exist_ok=True replaces the isdir()-then-makedirs() pair, so a
# directory that appears between the check and the creation no longer raises.
os.makedirs(tensorboard_folder, exist_ok=True)
os.makedirs(model_folder, exist_ok=True)

# Optional first command-line argument: handed to get_policy() as-is and,
# prefixed with '_', appended to the TensorBoard run name and checkpoint name.
policy = ''
model_tag = ''
if len(sys.argv) > 1:
    policy = sys.argv[1]
    model_tag = '_' + sys.argv[1]

if __name__ == '__main__':
    # Four ActionMaskEnv factories are handed to SubprocVecEnv; the result is
    # wrapped in VecFrameStack with 3 as its second argument.
    env = SubprocVecEnv([lambda: ActionMaskEnv() for _ in range(4)])
    env = VecFrameStack(env, 3)

    # Train ACKTR with the policy selected on the command line.
    model = ACKTR(get_policy(policy),
                  env,
                  n_steps=100,
                  verbose=0,
                  vf_fisher_coef=0.5,
                  tensorboard_log=tensorboard_folder,
                  kfac_update=10,
                  n_cpu_tf_sess=2,
                  async_eigen_decomp=False)
    model.learn(total_timesteps=100000000, tb_log_name='ACKTR_A2C' + model_tag)

    # Save the trained model, drop the in-memory instance and load it back
    # from the file that was just written.
    checkpoint_path = model_folder + "ACKTR_A2C" + model_tag
    model.save(checkpoint_path)
    del model
    model = ACKTR.load(checkpoint_path)

    # Assumes env.step() on the 4-environment vectorised env returns one done
    # flag per sub-environment (stable-baselines VecEnv contract) - confirm.
    # Truth-testing such a multi-element array with `not` raises ValueError,
    # so the flags are reduced with any(): the rollout stops as soon as at
    # least one sub-environment reports done.
    dones = [False]
    states = None
    action_masks = []
    obs = env.reset()

    while not any(dones):
        # NOTE(review): action_masks is never refilled and infos is unused,
        # so every predict() call receives an empty mask list - confirm
        # whether the masks were meant to be read back from infos.
        action, states = model.predict(obs, states, action_mask=action_masks)
        obs, _, dones, infos = env.step(action)
        env.render()
# Output locations for this experiment: TensorBoard logs and saved models.
tensorboard_folder = './tensorboard/Snake/action_mask/'
model_folder = './models/Snake/action_mask/'

# exist_ok=True replaces the isdir()-then-makedirs() pair, so a directory
# that appears between the check and the creation no longer raises.
os.makedirs(tensorboard_folder, exist_ok=True)
os.makedirs(model_folder, exist_ok=True)

# Optional first command-line argument: handed to get_policy() as-is and,
# prefixed with '_', appended to the TensorBoard run name and checkpoint name.
policy = ''
model_tag = ''
if len(sys.argv) > 1:
    policy = sys.argv[1]
    model_tag = '_' + sys.argv[1]

# Single ActionMaskEnv(10, 10) factory handed to DummyVecEnv.
env_factories = [lambda: ActionMaskEnv(10, 10)]
env = DummyVecEnv(env_factories)

# Train ACKTR (with gae_lambda set) using the policy selected on the
# command line.
run_name = 'ACKTR_PPO2' + model_tag
model = ACKTR(
    get_policy(policy),
    env,
    verbose=0,
    gae_lambda=0.95,
    tensorboard_log=tensorboard_folder,
)
model.learn(total_timesteps=10 ** 7, tb_log_name=run_name)

# Save the trained model, drop the in-memory instance and load it back from
# the file that was just written.
checkpoint_path = model_folder + "ACKTR_PPO2" + model_tag
model.save(checkpoint_path)
del model
model = ACKTR.load(checkpoint_path)

# Roll the reloaded model forward, rendering every step, until the
# environment reports done.
obs = env.reset()
states = None
action_masks = []
done = False

while not done:
    # NOTE(review): action_masks is never refilled and infos is unused, so
    # every predict() call receives an empty mask list - confirm whether the
    # masks were meant to be read back from infos.
    action, states = model.predict(obs, states, action_mask=action_masks)
    obs, _, done, infos = env.step(action)
    env.render()