Example #1
def make_env(config, writer, prefix, datadir, store):
    if "lunar" in config.task:
        env = wrappers.LunarLander(size=(64, 64),
                                   action_repeat=config.action_repeat)
    elif "Car" in config.task:
        env = wrappers.CarEnvWrapper(env_name=config.task,
                                     size=(64, 64),
                                     action_repeat=config.action_repeat,
                                     seed=config.seed)
    else:
        suite, task = config.task.split('_', 1)
        if suite == 'dmc':
            env = wrappers.DeepMindControl(task)
            env = wrappers.ActionRepeat(env, config.action_repeat)
            env = wrappers.NormalizeActions(env)
        elif suite == 'atari':
            env = wrappers.Atari(task,
                                 config.action_repeat, (64, 64),
                                 grayscale=False,
                                 life_done=True,
                                 sticky_actions=True)
            env = wrappers.OneHotAction(env)
        else:
            raise NotImplementedError(suite)
    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)
    callbacks = []
    if store:
        callbacks.append(lambda ep: tools.save_episodes(datadir, [ep]))
    callbacks.append(
        lambda ep: summarize_episode(ep, config, datadir, writer, prefix))
    env = wrappers.Collect(env, callbacks, config.precision)
    env = wrappers.RewardObs(env)
    return env
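For context, a minimal sketch of how a factory like this might be invoked. The config fields below are only the attributes `make_env` reads above; the concrete values, the `SimpleNamespace` stand-in, and the `None` writer are illustrative assumptions, not part of the original snippet.

import pathlib
from types import SimpleNamespace

# Hypothetical config covering only the attributes accessed in make_env above.
config = SimpleNamespace(task='dmc_walker_walk', action_repeat=2,
                         time_limit=1000, precision=32, seed=0)
datadir = pathlib.Path('./episodes')
writer = None  # stands in for whatever summary writer summarize_episode expects

train_env = make_env(config, writer, prefix='train', datadir=datadir, store=True)
obs = train_env.reset()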
Example #2
def make_env(config, writer, prefix, datadir, store):
    suite, task = config.task.split('_', 1)
    if suite == 'dmc':
        env = wrappers.DeepMindControl(task)
        env = wrappers.ActionRepeat(env, config.action_repeat)
        env = wrappers.NormalizeActions(env)
    elif suite == 'gym':
        env = wrappers.Gym(task, config, size=(128, 128))
        env = wrappers.ActionRepeat(env, config.action_repeat)
        env = wrappers.NormalizeActions(env)
    elif task == 'door':
        env = wrappers.DoorOpen(config, size=(128, 128))
        env = wrappers.ActionRepeat(env, config.action_repeat)
        env = wrappers.NormalizeActions(env)
    elif task == 'drawer':
        env = wrappers.DrawerOpen(config, size=(128, 128))
        env = wrappers.ActionRepeat(env, config.action_repeat)
        env = wrappers.NormalizeActions(env)
    else:
        raise NotImplementedError(suite)
    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)
    callbacks = []
    if store:
        callbacks.append(lambda ep: tools.save_episodes(datadir, [ep]))
    if prefix == 'test':
        callbacks.append(
            lambda ep: summarize_episode(ep, config, datadir, writer, prefix))
    env = wrappers.Collect(env, callbacks, config.precision)
    env = wrappers.RewardObs(env)
    return env
Example #3
def make_env(config, logger, mode, train_eps, eval_eps):
    suite, task = config.task.split('_', 1)
    if suite == 'dmc':
        env = wrappers.DeepMindControl(task, config.action_repeat, config.size)
        env = wrappers.NormalizeActions(env)
    elif suite == 'atari':
        env = wrappers.Atari(task,
                             config.action_repeat,
                             config.size,
                             grayscale=config.grayscale,
                             life_done=False and (mode == 'train'),  # always False here, so life-loss termination is disabled
                             sticky_actions=True,
                             all_actions=True)
        env = wrappers.OneHotAction(env)
    else:
        raise NotImplementedError(suite)
    env = wrappers.TimeLimit(env, config.time_limit)
    env = wrappers.SelectAction(env, key='action')
    callbacks = [
        functools.partial(process_episode, config, logger, mode, train_eps,
                          eval_eps)
    ]
    env = wrappers.CollectDataset(env, callbacks)
    env = wrappers.RewardObs(env)
    return env
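The `functools.partial` callback above would presumably be invoked by the collection wrapper with the finished episode as the last positional argument, mirroring the `lambda ep: ...` callbacks in the other examples. A self-contained sketch of that calling convention; the `process_episode` body and the dummy arguments are hypothetical:

import functools

def process_episode(config, logger, mode, train_eps, eval_eps, episode):
    # Hypothetical body: `episode` is assumed to be a dict of per-step arrays.
    score = float(sum(episode['reward']))
    print(f'{mode} episode finished with return {score:.1f}')

# Dummy stand-ins, just to illustrate how the bound callback would be called.
callback = functools.partial(process_episode, None, None, 'train', {}, {})
callback({'reward': [0.0, 1.0, 0.5]})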
Example #4
def make_env(config, writer, prefix, datadir, store):
    suite, task = config.task.split("_", 1)
    if suite == "dmc":
        env = wrappers.DeepMindControl(task)
        env = wrappers.ActionRepeat(env, config.action_repeat)
        env = wrappers.NormalizeActions(env)
    elif suite == "atari":
        env = wrappers.Atari(
            task,
            config.action_repeat,
            (64, 64),
            grayscale=False,
            life_done=True,
            sticky_actions=True,
        )
        env = wrappers.OneHotAction(env)
    else:
        raise NotImplementedError(suite)
    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)
    callbacks = []
    if store:
        callbacks.append(lambda ep: tools.save_episodes(datadir, [ep]))
    callbacks.append(lambda ep: summarize_episode(ep, config, datadir, writer, prefix))
    env = wrappers.Collect(env, callbacks, config.precision)
    env = wrappers.RewardObs(env)
    return env
Example #5
def make_test_env(config, writer, datadir, gui=False):
    env = make_base_env(config, gui)
    env = wrappers.FixedResetMode(env, mode='grid')
    env = wrappers.TimeLimit(env,
                             config.time_limit_test / config.action_repeat)
    # rendering
    render_callbacks = []
    render_callbacks.append(
        lambda videos: callbacks.save_videos(videos, config, datadir))
    env = wrappers.Render(env, render_callbacks)
    # summary
    callback_list = []
    callback_list.append(lambda episodes: callbacks.summarize_episode(
        episodes, config, datadir, writer, 'test'))
    env = wrappers.Collect(env, callback_list, config.precision)
    return env
Example #6
def make_env(config,
             writer,
             prefix,
             datadir,
             store,
             index=None,
             real_world=False):
    suite, task = config.task.split('_', 1)
    if suite == 'dmc':
        if config.dr is None or real_world:  # first index is always real world
            env = wrappers.DeepMindControl(task,
                                           use_state=config.use_state,
                                           real_world=real_world)
        else:
            env = wrappers.DeepMindControl(task,
                                           dr=config.dr,
                                           use_state=config.use_state,
                                           real_world=real_world)
        env = wrappers.ActionRepeat(env, config.action_repeat)
        env = wrappers.NormalizeActions(env)
    elif suite == 'atari':
        env = wrappers.Atari(task,
                             config.action_repeat, (64, 64),
                             grayscale=False,
                             life_done=True,
                             sticky_actions=True)
        env = wrappers.OneHotAction(env)
    elif suite == 'gym':
        if index == 0 or index is None:  # first index is always real world
            env = wrappers.GymControl(task)
        else:
            env = wrappers.GymControl(task, dr=config.dr)
        env = wrappers.ActionRepeat(env, config.action_repeat)
        env = wrappers.NormalizeActions(env)
    else:
        raise NotImplementedError(suite)
    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)
    callbacks = []
    if store:
        callbacks.append(lambda ep: tools.save_episodes(datadir, [ep]))
    callbacks.append(
        lambda ep: summarize_episode(ep, config, datadir, writer, prefix))
    env = wrappers.Collect(env, callbacks, config.precision)
    env = wrappers.RewardObs(env)
    return env
Example #7
def make_train_env(config, writer, datadir, gui=False):
    env = make_base_env(config, gui)
    if env.n_agents > 1:
        env = wrappers.FixedResetMode(
            env,
            mode='random_ball')  # sample reset points at random, close together within a ball
    else:
        env = wrappers.FixedResetMode(env, mode='random')
    env = wrappers.TimeLimit(env,
                             config.time_limit_train / config.action_repeat)
    # storing and summary
    callback_list = []
    callback_list.append(
        lambda episodes: callbacks.save_episodes(datadir, episodes))
    callback_list.append(lambda episodes: callbacks.summarize_episode(
        episodes, config, datadir, writer, 'train'))
    env = wrappers.Collect(env, callback_list, config.precision)
    return env
Example #8
def make_gridworld_env(config, writer, prefix, datadir, store, desc=None):
    ## NOTE: CHANGE: new function
    # what the env interface designed by dreamer looks like (see the interface sketch after this example)

    suite, task = config.task.split('_', 1)
    if suite == 'gridworld':
        env = wrappers.GridWorld(desc)
        env = wrappers.OneHotAction(env)
    else:
        raise NotImplementedError(suite)
    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)
    callbacks = []
    if store:
        callbacks.append(lambda ep: tools.save_episodes(datadir, [ep]))
    callbacks.append(
        lambda ep: summarize_episode(ep, config, datadir, writer, prefix))
    env = wrappers.Collect(env, callbacks, config.precision)
    # collect here
    env = wrappers.RewardObs(env)
    return env
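As a rough answer to the comment above about what environment interface these wrappers expect: a hypothetical skeleton of the minimal gym-style surface the wrapper chain (OneHotAction, TimeLimit, Collect, RewardObs) appears to assume, with dict observations containing an 'image' key. The space shapes and the Discrete(4) action space are illustrative assumptions, not taken from the original code.

import gym
import numpy as np

class MinimalDreamerEnv:

    @property
    def observation_space(self):
        return gym.spaces.Dict(
            {'image': gym.spaces.Box(0, 255, (64, 64, 3), dtype=np.uint8)})

    @property
    def action_space(self):
        # OneHotAction is typically wrapped around a discrete action space.
        return gym.spaces.Discrete(4)

    def reset(self):
        return {'image': np.zeros((64, 64, 3), np.uint8)}

    def step(self, action):
        obs = {'image': np.zeros((64, 64, 3), np.uint8)}
        reward, done, info = 0.0, False, {}
        return obs, reward, done, info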
Example #9
def make_env(config, writer, prefix, datadir, store):
    suite, task = config.task.split('_', 1)
    if suite == 'dmc':
        env = wrappers.DeepMindControl(task)
        env = wrappers.ActionRepeat(env, config.action_repeat)
        env = wrappers.NormalizeActions(env)
    elif suite == 'atari':
        env = wrappers.Atari(task,
                             config.action_repeat, (64, 64),
                             grayscale=False,
                             life_done=True,
                             sticky_actions=True)
        env = wrappers.OneHotAction(env)
    elif suite == 'football':
        env = football_env.create_environment(
            representation='pixels',
            env_name='academy_empty_goal_close',
            stacked=False,
            logdir='./football/empty_goal_close2',
            write_goal_dumps=True,
            write_full_episode_dumps=True,
            render=True,
            write_video=True)
        env = wrappers.Football(env)
        env = wrappers.OneHotAction(env)
    else:
        raise NotImplementedError(suite)
    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)
    callbacks = []
    if store:
        callbacks.append(lambda ep: tools.save_episodes(datadir, [ep]))
    callbacks.append(
        lambda ep: summarize_episode(ep, config, datadir, writer, prefix))
    env = wrappers.Collect(env, callbacks, config.precision)
    env = wrappers.RewardObs(env)
    return env
Example #10
import gym
import wrappers
import numpy as np

task = 'SingleAgentTreitlstrasse_v2-v0'
time_limit = 60 * 100
action_repeat = 8
env = gym.make(task)
env = wrappers.TimeLimit(env, time_limit)
env = wrappers.ActionRepeat(env, action_repeat)


def test_on_track(model, outdir):
    video, returns = simulate_episode(model)
    videodir = outdir / 'videos'
    videodir.mkdir(parents=True, exist_ok=True)
    import imageio
    writer = imageio.get_writer(videodir / f'test_return{returns}.mp4')
    for image in video:
        writer.append_data(image)
    writer.close()


def simulate_episode(model, prediction_window=5, terminate_on_collision=True):
    # TODO: make this consistent with the f1_tenth directory
    done = False
    obs = env.reset(mode='grid')
    state = None
    video = []
    returns = 0.0
    while not done: