コード例 #1
0
def main():
    """Run one episode with a loaded IMPALA model and export it as a video.

    Parses the command-line arguments, builds a single environment wrapped in
    ObsInInfo and wrap_env, creates an IMPALAModel inside a TensorFlow
    session, initializes the variables, calls load_vars with args.save_path,
    and then streams frames of one episode to export_video. The environment
    is closed when the function exits, whether or not the export succeeded.
    """
    options = arg_parser().parse_args()
    print('Creating environments...')
    env = wrap_env(ObsInInfo(create_env(options.env, 1, 1, options.fps,
                                        options.max_timesteps)))
    try:
        print('Creating session...')
        with tf.Session() as sess:
            print('Creating model graph...')
            model = IMPALAModel(sess, *gym_spaces(env))
            print('Initializing model variables...')
            sess.run(tf.global_variables_initializer())
            load_vars(sess, options.save_path)
            print('Gathering episode...')

            def episode_frames():
                """Yield one padded frame per step until the episode is done."""
                model_states = model.start_state(1)
                env.reset_start()
                observations = env.reset_wait()
                episode_over = False
                while not episode_over:
                    step_out = model.step(observations, model_states)
                    env.step_start(step_out['actions'])
                    observations, _, dones, infos = env.step_wait()
                    model_states = step_out['states']
                    # Frames come from the 'old_obs' info entry of the first
                    # (only) environment; presumably the raw observation saved
                    # by ObsInInfo -- confirm against that wrapper.
                    yield pad_dims(infos[0]['old_obs'])
                    episode_over = dones[0]

            spec = muniverse.spec_for_name(options.env)
            video_width = padded_dim(spec['Width'])
            video_height = padded_dim(spec['Height'])
            export_video(options.path, video_width, video_height, options.fps,
                         episode_frames())
    finally:
        env.close()
コード例 #2
0
ファイル: export.py プロジェクト: wwxFromTju/obs-tower2
def main():
    """Record one episode played by the saved models into 'export.mp4'.

    Two ACModel instances are loaded from 'save.pkl' and 'save_tail.pkl'.
    The first one picks actions while the reported floor is below 10; the
    tail model takes over from floor 10 onwards. Each step yields one frame
    built by big_obs, and export_video writes the frames as a 168x168 video
    at 10 frames per second.
    """
    def image_fn():
        """Yield one rendered frame per environment step until the episode ends."""
        # NOTE(review): the random value in [15, 20) is passed as the first
        # argument of create_single_env; presumably a worker/port index --
        # confirm against that helper.
        env = StateEnv(create_single_env(random.randrange(15, 20), clear=False))
        # Everything after the environment is created runs under try/finally so
        # the environment is closed even when loading a checkpoint fails, a
        # step raises, or the consumer stops iterating before the episode ends
        # (closing a generator runs its pending finally blocks).
        try:
            model = ACModel()
            tail_model = ACModel()
            model.load_state_dict(torch.load('save.pkl', map_location='cpu'))
            tail_model.load_state_dict(torch.load('save_tail.pkl', map_location='cpu'))
            model.to(torch.device('cuda'))
            tail_model.to(torch.device('cuda'))
            state, obs = env.reset()
            floor = 0
            while True:
                # Pick the policy by the floor reported on the previous step.
                active_model = model if floor < 10 else tail_model
                output = active_model.step(np.array([state]), np.array([obs]))
                (state, obs), rew, done, info = env.step(output['actions'][0])
                floor = info['current_floor']
                # Only the last three entries of the final axis are rendered;
                # presumably the newest RGB frame of a frame stack -- confirm.
                yield big_obs(obs[..., -3:], info)
                if done:
                    break
        finally:
            env.close()

    export_video('export.mp4', 168, 168, 10, image_fn())
コード例 #3
0
ファイル: stuck_box.py プロジェクト: wwxFromTju/obs-tower2
import json
import os

from anyrl.utils.ffmpeg import export_video
import numpy as np
from obstacle_tower_env import ObstacleTowerEnv

from obs_tower2.util import big_obs

# Load the recorded actions that f() replays one by one.
with open('stuck_box.json', 'r') as in_file:
    data = json.load(in_file)

# Start the environment from the path given in the OBS_TOWER_PATH environment
# variable (raises KeyError if it is unset).
env = ObstacleTowerEnv(os.environ['OBS_TOWER_PATH'], worker_id=1)
# Seed before resetting; presumably this makes the replay reproduce the
# recorded run -- confirm against how stuck_box.json was captured.
env.seed(56)
env.reset()


def f():
    """Replay every recorded action and yield frames for the late steps only.

    All actions in the module-level ``data`` are applied to ``env`` in order,
    but a frame (built by big_obs) is yielded only for steps whose index is
    greater than 5275.
    """
    for step_index, action in enumerate(data):
        frame, _, _, step_info = env.step(action)
        if step_index <= 5275:
            # Still inside the skipped prefix: step the env, emit nothing.
            continue
        yield big_obs(frame, step_info)


# Write the replayed frames to 'stuck_box.mp4' (168x168, 10 fps). The
# environment is closed afterwards even if the export fails, so the
# environment does not stay open after the script ends.
try:
    export_video('stuck_box.mp4', 168, 168, 10, f())
finally:
    env.close()
コード例 #4
0
import json
import os

from anyrl.utils.ffmpeg import export_video
import numpy as np
from obstacle_tower_env import ObstacleTowerEnv

from obs_tower2.util import big_obs

# Load the recording; f() reads its 'seed', 'floor' and 'actions' entries.
with open('hang.json', 'r') as in_file:
    data = json.load(in_file)

# Start the environment from the path given in the OBS_TOWER_PATH environment
# variable (raises KeyError if it is unset).
env = ObstacleTowerEnv(os.environ['OBS_TOWER_PATH'], worker_id=4)
# Initial reset; f() seeds, selects the floor and resets again before replaying.
env.reset()


def f():
    """Replay the recorded episode and yield one frame per replayed action.

    The environment is seeded and set to the recorded floor, then reset, and
    every recorded action except the last one is applied in order. Each step
    yields a frame built by big_obs.
    """
    env.seed(data['seed'])
    env.floor(data['floor'])
    env.reset()
    # The final recorded action is deliberately left out; presumably it is the
    # one that triggers the hang -- confirm against how hang.json was captured.
    replayed_actions = data['actions'][:-1]
    for action in replayed_actions:
        frame, _, _, step_info = env.step(action)
        yield big_obs(frame, step_info)


# Write the replayed frames to 'hang.mp4' (168x168, 10 fps). The environment
# is closed afterwards even if the export fails, so the environment does not
# stay open after the script ends.
try:
    export_video('hang.mp4', 168, 168, 10, f())
finally:
    env.close()