예제 #1
0
파일: dreamer.py 프로젝트: roggirg/dreamer
def make_env(config, writer, prefix, datadir, store):
    """Build the fully wrapped environment selected by ``config.task``.

    Task names are dispatched in this order:
      * any task containing ``"lunar"``  -> ``wrappers.LunarLander``;
      * any task containing ``"Car"``    -> ``wrappers.CarEnvWrapper``
        (seeded with ``config.seed``);
      * otherwise ``"<suite>_<task>"`` with suite ``dmc`` or ``atari``.

    The base environment is then wrapped, innermost first, in ``TimeLimit``,
    ``Collect`` and ``RewardObs``.

    Args:
      config: Run configuration; reads ``task``, ``action_repeat``, ``seed``
        (Car tasks only), ``time_limit`` and ``precision``.
      writer: Passed through to ``summarize_episode``.
      prefix: Passed through to ``summarize_episode``.
      datadir: Passed to ``tools.save_episodes`` and ``summarize_episode``.
      store: When truthy, an extra callback saves each episode via
        ``tools.save_episodes``.

    Returns:
      The wrapped environment.

    Raises:
      NotImplementedError: If the suite prefix is neither ``dmc`` nor ``atari``.
      ValueError: If a non-lunar/non-Car task name contains no ``_``.
    """
    task_name = config.task
    # NOTE(review): both checks are plain, case-sensitive substring tests made
    # before any suite prefix is parsed, so e.g. any task containing "Car"
    # takes the CarEnvWrapper branch regardless of its prefix.
    if "lunar" in task_name:
        env = wrappers.LunarLander(
            size=(64, 64), action_repeat=config.action_repeat)
    elif "Car" in task_name:
        env = wrappers.CarEnvWrapper(
            env_name=task_name,
            size=(64, 64),
            action_repeat=config.action_repeat,
            seed=config.seed,
        )
    else:
        suite, task = task_name.split('_', 1)
        if suite == 'dmc':
            env = wrappers.NormalizeActions(
                wrappers.ActionRepeat(
                    wrappers.DeepMindControl(task), config.action_repeat))
        elif suite == 'atari':
            env = wrappers.OneHotAction(
                wrappers.Atari(task, config.action_repeat, (64, 64),
                               grayscale=False, life_done=True,
                               sticky_actions=True))
        else:
            raise NotImplementedError(suite)

    # The limit is divided by action_repeat before being handed to TimeLimit.
    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)

    def save(episode):
        return tools.save_episodes(datadir, [episode])

    def summarize(episode):
        return summarize_episode(episode, config, datadir, writer, prefix)

    callbacks = [save, summarize] if store else [summarize]
    env = wrappers.Collect(env, callbacks, config.precision)
    return wrappers.RewardObs(env)
예제 #2
0
def make_env(config, logger, mode, train_eps, eval_eps):
    """Build the wrapped environment for ``config.task`` (``"<suite>_<task>"``).

    Supported suites are ``dmc`` and ``atari``. The base environment is wrapped,
    innermost first, in ``TimeLimit``, ``SelectAction(key='action')``,
    ``CollectDataset`` and ``RewardObs``.

    Args:
      config: Run configuration; reads ``task``, ``action_repeat``, ``size``,
        ``grayscale`` (atari only) and ``time_limit``.
      logger: Bound into the ``process_episode`` callback.
      mode: Bound into the ``process_episode`` callback.
      train_eps: Bound into the ``process_episode`` callback.
      eval_eps: Bound into the ``process_episode`` callback.

    Returns:
      The wrapped environment.

    Raises:
      NotImplementedError: If the suite prefix is neither ``dmc`` nor ``atari``.
      ValueError: If ``config.task`` contains no ``_``.
    """
    suite, task = config.task.split('_', 1)
    if suite == 'dmc':
        env = wrappers.DeepMindControl(task, config.action_repeat, config.size)
        env = wrappers.NormalizeActions(env)
    elif suite == 'atari':
        # The original expression `False and (mode == 'train')` always
        # evaluated to False (and never even evaluated `mode == 'train'`),
        # so losing a life never ends an episode here in any mode. To end
        # episodes on life loss during training only, pass
        # `life_done=(mode == 'train')`.
        env = wrappers.Atari(task,
                             config.action_repeat,
                             config.size,
                             grayscale=config.grayscale,
                             life_done=False,
                             sticky_actions=True,
                             all_actions=True)
        env = wrappers.OneHotAction(env)
    else:
        raise NotImplementedError(suite)
    # config.time_limit is passed to TimeLimit unchanged here, unlike the
    # variants that divide it by action_repeat.
    env = wrappers.TimeLimit(env, config.time_limit)
    env = wrappers.SelectAction(env, key='action')
    # process_episode is resolved now; the callback receives its remaining
    # arguments from CollectDataset.
    callbacks = [
        functools.partial(process_episode, config, logger, mode, train_eps,
                          eval_eps)
    ]
    env = wrappers.CollectDataset(env, callbacks)
    env = wrappers.RewardObs(env)
    return env
예제 #3
0
def make_env(config, writer, prefix, datadir, store):
    """Create the wrapped environment described by ``config.task``.

    ``config.task`` has the form ``"<suite>_<task>"`` with suite ``dmc`` or
    ``atari``. The base environment is wrapped, innermost first, in
    ``TimeLimit``, ``Collect`` and ``RewardObs``.

    Args:
      config: Run configuration; reads ``task``, ``action_repeat``,
        ``time_limit`` and ``precision``.
      writer: Passed through to ``summarize_episode``.
      prefix: Passed through to ``summarize_episode``.
      datadir: Passed to ``tools.save_episodes`` and ``summarize_episode``.
      store: When truthy, episodes are also saved with ``tools.save_episodes``.

    Returns:
      The wrapped environment.

    Raises:
      NotImplementedError: For any suite other than ``dmc``/``atari``.
      ValueError: If ``config.task`` contains no ``_``.
    """
    suite, task = config.task.split("_", 1)
    if suite not in ("dmc", "atari"):
        raise NotImplementedError(suite)

    if suite == "dmc":
        env = wrappers.NormalizeActions(
            wrappers.ActionRepeat(wrappers.DeepMindControl(task), config.action_repeat)
        )
    else:  # atari
        env = wrappers.OneHotAction(
            wrappers.Atari(
                task,
                config.action_repeat,
                (64, 64),
                grayscale=False,
                life_done=True,
                sticky_actions=True,
            )
        )

    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)

    def _save(ep):
        return tools.save_episodes(datadir, [ep])

    def _summarize(ep):
        return summarize_episode(ep, config, datadir, writer, prefix)

    callbacks = ([_save] if store else []) + [_summarize]
    return wrappers.RewardObs(wrappers.Collect(env, callbacks, config.precision))
예제 #4
0
def make_env(config, writer, prefix, datadir, store):
    """Create the wrapped environment for ``config.task`` (``"<suite>_<task>"``).

    Supported suites are ``dmc``, ``atari`` and ``football``. The base
    environment is wrapped, innermost first, in ``TimeLimit``, ``Collect`` and
    ``RewardObs``.

    Args:
      config: Run configuration; reads ``task``, ``action_repeat``,
        ``time_limit`` and ``precision``.
      writer: Passed through to ``summarize_episode``.
      prefix: Passed through to ``summarize_episode``.
      datadir: Passed to ``tools.save_episodes`` and ``summarize_episode``.
      store: When truthy, episodes are also saved with ``tools.save_episodes``.

    Returns:
      The wrapped environment.

    Raises:
      NotImplementedError: For an unrecognised suite.
      ValueError: If ``config.task`` contains no ``_``.
    """
    suite, task = config.task.split('_', 1)
    if suite == 'dmc':
        base = wrappers.DeepMindControl(task)
        env = wrappers.NormalizeActions(
            wrappers.ActionRepeat(base, config.action_repeat))
    elif suite == 'atari':
        base = wrappers.Atari(task, config.action_repeat, (64, 64),
                              grayscale=False, life_done=True,
                              sticky_actions=True)
        env = wrappers.OneHotAction(base)
    elif suite == 'football':
        # NOTE(review): the scenario, log directory and dump/render flags are
        # hard-coded; the `task` part of config.task is not used here.
        football_options = dict(
            representation='pixels',
            env_name='academy_empty_goal_close',
            stacked=False,
            logdir='./football/empty_goal_close2',
            write_goal_dumps=True,
            write_full_episode_dumps=True,
            render=True,
            write_video=True,
        )
        base = football_env.create_environment(**football_options)
        env = wrappers.OneHotAction(wrappers.Football(base))
    else:
        raise NotImplementedError(suite)

    # NOTE(review): the limit is divided by action_repeat for every suite,
    # although the football branch applies no ActionRepeat wrapper — confirm.
    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)

    def save(episode):
        return tools.save_episodes(datadir, [episode])

    def summarize(episode):
        return summarize_episode(episode, config, datadir, writer, prefix)

    callbacks = [save, summarize] if store else [summarize]
    env = wrappers.Collect(env, callbacks, config.precision)
    return wrappers.RewardObs(env)
예제 #5
0
def make_env(config,
             writer,
             prefix,
             datadir,
             store,
             index=None,
             real_world=False):
    """Create a wrapped environment, optionally with domain randomization.

    ``config.task`` has the form ``"<suite>_<task>"`` with suite ``dmc``,
    ``atari`` or ``gym``. Domain-randomization settings (``config.dr``) are
    forwarded to the base environment except for the "real world" instance:
      * dmc: no ``dr`` when ``config.dr is None`` or ``real_world`` is truthy;
      * gym: no ``dr`` when ``index`` is ``None`` or ``0``.

    The base environment is wrapped, innermost first, in ``TimeLimit``,
    ``Collect`` and ``RewardObs``.

    Args:
      config: Run configuration; reads ``task``, ``dr``, ``use_state`` (dmc),
        ``action_repeat``, ``time_limit`` and ``precision``.
      writer: Passed through to ``summarize_episode``.
      prefix: Passed through to ``summarize_episode``.
      datadir: Passed to ``tools.save_episodes`` and ``summarize_episode``.
      store: When truthy, episodes are also saved with ``tools.save_episodes``.
      index: Environment index; only consulted by the gym suite.
      real_world: Forwarded to DeepMindControl; only consulted by dmc.

    Returns:
      The wrapped environment.

    Raises:
      NotImplementedError: For an unrecognised suite.
      ValueError: If ``config.task`` contains no ``_``.
    """
    suite, task = config.task.split('_', 1)
    if suite == 'atari':
        env = wrappers.OneHotAction(
            wrappers.Atari(task, config.action_repeat, (64, 64),
                           grayscale=False, life_done=True,
                           sticky_actions=True))
    elif suite in ('dmc', 'gym'):
        if suite == 'dmc':
            # The real-world instance never receives randomization settings.
            is_real = config.dr is None or real_world
            dmc_kwargs = {'use_state': config.use_state,
                          'real_world': real_world}
            if not is_real:
                dmc_kwargs['dr'] = config.dr
            env = wrappers.DeepMindControl(task, **dmc_kwargs)
        elif index is None or index == 0:
            # First (or only) gym instance is always the real world.
            env = wrappers.GymControl(task)
        else:
            env = wrappers.GymControl(task, dr=config.dr)
        env = wrappers.ActionRepeat(env, config.action_repeat)
        env = wrappers.NormalizeActions(env)
    else:
        raise NotImplementedError(suite)

    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)

    def save(episode):
        return tools.save_episodes(datadir, [episode])

    def summarize(episode):
        return summarize_episode(episode, config, datadir, writer, prefix)

    callbacks = [save, summarize] if store else [summarize]
    env = wrappers.Collect(env, callbacks, config.precision)
    return wrappers.RewardObs(env)
예제 #6
0
def make_gridworld_env(config, writer, prefix, datadir, store, desc=None):
    """Create a wrapped grid-world environment.

    NOTE(change): function added in this fork. It builds the same wrapper
    stack as the ``make_env`` variants (TimeLimit -> Collect -> RewardObs)
    around a ``wrappers.GridWorld`` base with one-hot actions.

    Args:
      config: Run configuration; ``config.task`` must have the form
        ``"gridworld_<anything>"``. Also reads ``time_limit``,
        ``action_repeat`` and ``precision``.
      writer: Passed through to ``summarize_episode``.
      prefix: Passed through to ``summarize_episode``.
      datadir: Passed to ``tools.save_episodes`` and ``summarize_episode``.
      store: When truthy, episodes are also saved with ``tools.save_episodes``.
      desc: Forwarded unchanged to ``wrappers.GridWorld`` (presumably the map
        layout — confirm against the wrapper).

    Returns:
      The wrapped environment.

    Raises:
      NotImplementedError: If the suite prefix is not ``gridworld``.
      ValueError: If ``config.task`` contains no ``_``.
    """
    # Only the suite prefix is used; the remainder of the task name is ignored.
    suite, _task = config.task.split('_', 1)
    if suite != 'gridworld':
        raise NotImplementedError(suite)

    env = wrappers.OneHotAction(wrappers.GridWorld(desc))
    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)

    def save(episode):
        return tools.save_episodes(datadir, [episode])

    def summarize(episode):
        return summarize_episode(episode, config, datadir, writer, prefix)

    callbacks = [save, summarize] if store else [summarize]
    # Episode collection: Collect is given the callbacks built above.
    env = wrappers.Collect(env, callbacks, config.precision)
    return wrappers.RewardObs(env)