示例#1
0
def test_get_reward(base_env: BaseEnv):
    """Check the reward of a single step and the value of ``get_reward()``.

    After a reset, one step with action 4 is expected to yield a reward of
    -0.01.  ``state.score_item`` is then overwritten with a fixed list and
    ``get_reward()`` is expected to report 30.
    """
    base_env.reset()

    # Only the reward element of the 4-tuple returned by step() is needed.
    _, step_reward, _, _ = base_env.step(4)
    assert step_reward == -0.01

    fixed_scores = [10, 200, -500, 500]
    base_env.state.score_item = fixed_scores
    assert base_env.get_reward() == 30
示例#2
0
def test_render(base_env: BaseEnv):
    """Check which helpers ``render`` invokes.

    ``sleep`` and ``system`` are patched on the ``time`` and ``os`` objects
    reachable through ``base_env_py`` so the calls can be inspected.
    """
    delay = 10000

    with mock.patch.object(base_env_py.time, "sleep") as sleep_stub:
        base_env.render(delay)

    # call_args is an (args, kwargs) pair; the first positional argument
    # passed to sleep() must be the value given to render().
    positional_args, _ = sleep_stub.call_args
    assert positional_args[0] == delay

    with mock.patch.object(base_env_py.os, "system") as system_stub:
        base_env.render(0)

    assert system_stub.called
示例#3
0
def test_step(base_env: BaseEnv):
    """Drive the environment through a fixed action sequence.

    Each step's reward is checked against the expected value; the bomberman's
    direction is checked after the action-2 steps, and the episode is expected
    to be done after the final step.
    """
    base_env.reset()

    _, reward, done, _ = base_env.step(5)
    assert reward == 0.01
    assert not done

    # Three consecutive action-2 steps: each is expected to return -0.01 and
    # leave the bomberman facing East.
    for _repeat in range(3):
        _, reward, done, _ = base_env.step(2)
        assert reward == -0.01
        facing = base_env.state.get_bomberman().get_direction()
        assert facing == 'East'

    _, reward, done, _ = base_env.step(4)
    assert reward == -0.01

    _, reward, done, _ = base_env.step(4)
    assert reward == 10

    # Add a bomb at (3, 3) with its countdown set to 1 to the state's bomb
    # list before taking the last step.
    armed_bomb = Bomb((3, 3))
    armed_bomb.countdown = 1
    base_env.state.get_bombs().append(armed_bomb)

    _, reward, done, _ = base_env.step(4)
    assert reward == 80
    assert done
from examples.utils.utils import get_policy

# Output locations for the TensorBoard logs and the saved model.
tensorboard_folder = './tensorboard/Bomberman/base/'
model_folder = './models/Bomberman/base/'
# exist_ok=True creates the directories only when they are missing and avoids
# the check-then-create race of a separate os.path.isdir() test.
os.makedirs(tensorboard_folder, exist_ok=True)
os.makedirs(model_folder, exist_ok=True)

# The optional first command-line argument is passed to get_policy() and is
# also appended (with a leading underscore) to the log and model names.
policy = ''
model_tag = ''
if len(sys.argv) > 1:
    policy = sys.argv[1]
    model_tag = '_' + policy

# Name shared by the TensorBoard run and the saved model file.
run_name = 'A2C' + model_tag
model_path = model_folder + run_name

# Single BaseEnv wrapped in a DummyVecEnv, then in a VecFrameStack with 3 as
# its second argument.
env = DummyVecEnv([lambda: BaseEnv()])
env = VecFrameStack(env, 3)

model = A2C(get_policy(policy),
            env,
            verbose=0,
            tensorboard_log=tensorboard_folder)
model.learn(total_timesteps=2500000, tb_log_name=run_name)

# Save the trained model, drop the in-memory instance and load it back from
# disk so that the code below works with the persisted model.
model.save(model_path)
del model
model = A2C.load(model_path)

# Initial values for the code that follows this fragment.
done = False
states = None
obs = env.reset()
def state():
    """Return the ``state`` attribute of a new ``BaseEnv`` built from 'test_map'.

    NOTE(review): presumably registered as a pytest fixture; the decorator is
    not visible in this fragment - confirm.
    """
    env = BaseEnv('test_map')
    return env.state
def base_env():
    """Return a new ``BaseEnv`` built from 'test_map'.

    NOTE(review): presumably registered as a pytest fixture; the decorator is
    not visible in this fragment - confirm.
    """
    env = BaseEnv('test_map')
    return env
示例#7
0
def test_reset(base_env: BaseEnv):
    """Check that ``reset()`` puts ``current_step`` back to 0."""
    # Start from a non-zero counter so the final assertion cannot pass
    # trivially.
    base_env.current_step = 1

    base_env.reset()

    step_after_reset = base_env.current_step
    assert step_after_reset == 0