Example #1
0
def make_env(config, writer, prefix, datadir, store):
    """Build the wrapped environment selected by ``config.task``.

    ``config.task`` is split on its first underscore into ``suite`` and
    ``task``. ``dmc`` and ``gym`` are dispatched on the suite; ``door`` and
    ``drawer`` are dispatched on the task part alone (the suite prefix is not
    checked for them). Every supported environment then receives the same
    action repeat / action normalization, a time limit, episode callbacks and
    reward observations.

    Args:
      config: Configuration object; reads ``task``, ``action_repeat``,
        ``time_limit`` and ``precision``, and is forwarded whole to the
        ``gym``/``door``/``drawer`` wrappers.
      writer: Summary writer forwarded to ``summarize_episode``.
      prefix: Run label; episodes are summarized only when it is ``'test'``.
      datadir: Directory passed to ``tools.save_episodes`` and
        ``summarize_episode``.
      store: If true, every finished episode is saved via
        ``tools.save_episodes``.

    Returns:
      The fully wrapped environment.

    Raises:
      ValueError: If ``config.task`` contains no underscore.
      NotImplementedError: If ``config.task`` matches no supported
        environment; the exception message is the full task string.
    """
    suite, task = config.task.split('_', 1)
    if suite == 'dmc':
        env = wrappers.DeepMindControl(task)
    elif suite == 'gym':
        env = wrappers.Gym(task, config, size=(128, 128))
    elif task == 'door':
        env = wrappers.DoorOpen(config, size=(128, 128))
    elif task == 'drawer':
        env = wrappers.DrawerOpen(config, size=(128, 128))
    else:
        # Report the whole task string: door/drawer are matched on the task
        # part, so naming only the suite would point at the wrong thing.
        raise NotImplementedError(config.task)
    # Every supported environment shares the same action processing.
    env = wrappers.ActionRepeat(env, config.action_repeat)
    env = wrappers.NormalizeActions(env)
    # TimeLimit wraps the action-repeated env, hence the division.
    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)
    callbacks = []
    if store:
        callbacks.append(lambda ep: tools.save_episodes(datadir, [ep]))
    if prefix == 'test':
        callbacks.append(
            lambda ep: summarize_episode(ep, config, datadir, writer, prefix))
    env = wrappers.Collect(env, callbacks, config.precision)
    env = wrappers.RewardObs(env)
    return env
Example #2
0
def make_env(config, writer, prefix, datadir, store):
    """Create and wrap the environment named by ``config.task``.

    Task names containing ``"lunar"`` or ``"Car"`` go to dedicated wrappers
    (64x64 observations, action repeat handled by the wrapper itself). Any
    other name must look like ``<suite>_<task>`` with suite ``dmc`` or
    ``atari``. The result gets a time limit, episode callbacks (optional
    storage plus a summary) and reward observations.

    Raises:
      ValueError: If a suite-style task name has no underscore.
      NotImplementedError: For an unknown suite prefix.
    """
    task_name = config.task
    if "lunar" in task_name:
        env = wrappers.LunarLander(size=(64, 64),
                                   action_repeat=config.action_repeat)
    elif "Car" in task_name:
        env = wrappers.CarEnvWrapper(
            env_name=task_name,
            size=(64, 64),
            action_repeat=config.action_repeat,
            seed=config.seed,
        )
    else:
        suite, task = task_name.split('_', 1)
        if suite == 'atari':
            atari = wrappers.Atari(task,
                                   config.action_repeat, (64, 64),
                                   grayscale=False,
                                   life_done=True,
                                   sticky_actions=True)
            env = wrappers.OneHotAction(atari)
        elif suite == 'dmc':
            repeated = wrappers.ActionRepeat(wrappers.DeepMindControl(task),
                                             config.action_repeat)
            env = wrappers.NormalizeActions(repeated)
        else:
            raise NotImplementedError(suite)
    step_budget = config.time_limit / config.action_repeat
    env = wrappers.TimeLimit(env, step_budget)

    def _save_episode(ep):
        tools.save_episodes(datadir, [ep])

    def _summarize(ep):
        summarize_episode(ep, config, datadir, writer, prefix)

    episode_hooks = [_save_episode] if store else []
    episode_hooks.append(_summarize)
    env = wrappers.Collect(env, episode_hooks, config.precision)
    return wrappers.RewardObs(env)
Example #3
0
def make_env(config, writer, prefix, datadir, store):
    """Create the wrapped ``dmc`` or ``atari`` environment for ``config.task``.

    ``config.task`` must be ``<suite>_<task>``. After construction the env
    gets a time limit, episode callbacks (optional storage plus a summary)
    and reward observations.

    Raises:
      ValueError: If ``config.task`` has no underscore.
      NotImplementedError: If the suite is neither ``dmc`` nor ``atari``.
    """
    suite, task = config.task.split("_", 1)
    if suite not in ("dmc", "atari"):
        raise NotImplementedError(suite)
    if suite == "atari":
        atari = wrappers.Atari(
            task,
            config.action_repeat,
            (64, 64),
            grayscale=False,
            life_done=True,
            sticky_actions=True,
        )
        env = wrappers.OneHotAction(atari)
    else:
        control = wrappers.DeepMindControl(task)
        control = wrappers.ActionRepeat(control, config.action_repeat)
        env = wrappers.NormalizeActions(control)
    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)

    def _store(ep):
        tools.save_episodes(datadir, [ep])

    def _report(ep):
        summarize_episode(ep, config, datadir, writer, prefix)

    hooks = [_store, _report] if store else [_report]
    env = wrappers.Collect(env, hooks, config.precision)
    return wrappers.RewardObs(env)
Example #4
0
def make_initialization_env(config):
    """Create the environment, summary writer and data dir used for initialization.

    Builds a non-rendering ``'columbia'`` single-track environment, a
    TensorFlow summary file writer on ``.tmp``, and wraps the environment in
    ``wrappers.Collect`` with one callback that forwards what it receives to
    ``callbacks.save_episodes`` under that same ``.tmp`` directory.

    Returns:
      A ``(writer, datadir, env)`` tuple.
    """
    base_env = make_single_track_env('columbia',
                                     action_repeat=config.action_repeat,
                                     rendering=False)
    tmp_dir = pathlib.Path('.tmp')
    summary_writer = tf.summary.create_file_writer(
        str(tmp_dir), max_queue=1000, flush_millis=20000)

    def _save(episodes):
        callbacks.save_episodes(tmp_dir, episodes)

    collected = wrappers.Collect(base_env, [_save], config.precision)
    return summary_writer, tmp_dir, collected
Example #5
0
def make_test_env(config, writer, datadir, gui=False):
    """Build the evaluation environment.

    Resets use the fixed ``'grid'`` mode, episodes are capped at
    ``config.time_limit_test / config.action_repeat`` steps, rendered videos
    are passed to ``callbacks.save_videos``, and collected episodes are
    summarized under the ``'test'`` tag.
    """
    env = wrappers.FixedResetMode(make_base_env(config, gui), mode='grid')
    horizon = config.time_limit_test / config.action_repeat
    env = wrappers.TimeLimit(env, horizon)

    def _store_videos(videos):
        callbacks.save_videos(videos, config, datadir)

    env = wrappers.Render(env, [_store_videos])

    def _log_summary(episodes):
        callbacks.summarize_episode(episodes, config, datadir, writer, 'test')

    return wrappers.Collect(env, [_log_summary], config.precision)
Example #6
0
def make_env(config,
             writer,
             prefix,
             datadir,
             store,
             index=None,
             real_world=False):
    """Build a wrapped environment, optionally with domain randomization.

    ``config.task`` is ``<suite>_<task>`` where suite is ``dmc``, ``atari``
    or ``gym``. After construction the env gets a time limit, episode
    callbacks (optional storage plus a summary) and reward observations.

    Args:
      config: Reads ``task``, ``dr``, ``use_state``, ``action_repeat``,
        ``time_limit`` and ``precision``.
      writer: Forwarded to ``summarize_episode``.
      prefix: Forwarded to ``summarize_episode``.
      datadir: Passed to ``tools.save_episodes`` and ``summarize_episode``.
      store: If true, finished episodes are saved via ``tools.save_episodes``.
      index: Used by the ``gym`` suite only; ``None`` or ``0`` builds
        ``GymControl`` without a ``dr`` argument.
      real_world: Used by the ``dmc`` suite only; forwarded to
        ``DeepMindControl`` and, when true, suppresses the ``dr`` argument.

    Raises:
      ValueError: If ``config.task`` has no underscore.
      NotImplementedError: For an unknown suite.
    """
    suite, task = config.task.split('_', 1)
    if suite == 'atari':
        atari = wrappers.Atari(task,
                               config.action_repeat, (64, 64),
                               grayscale=False,
                               life_done=True,
                               sticky_actions=True)
        env = wrappers.OneHotAction(atari)
    elif suite == 'dmc':
        # The real-world instance, and any run without domain randomization,
        # is built without a ``dr`` argument.
        randomize = config.dr is not None and not real_world
        dmc_kwargs = {'use_state': config.use_state, 'real_world': real_world}
        if randomize:
            dmc_kwargs['dr'] = config.dr
        env = wrappers.DeepMindControl(task, **dmc_kwargs)
        env = wrappers.NormalizeActions(
            wrappers.ActionRepeat(env, config.action_repeat))
    elif suite == 'gym':
        # Index 0 (or no index at all) is the real-world instance.
        if index in (0, None):
            env = wrappers.GymControl(task)
        else:
            env = wrappers.GymControl(task, dr=config.dr)
        env = wrappers.NormalizeActions(
            wrappers.ActionRepeat(env, config.action_repeat))
    else:
        raise NotImplementedError(suite)
    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)

    def _save_episode(ep):
        tools.save_episodes(datadir, [ep])

    def _summarize(ep):
        summarize_episode(ep, config, datadir, writer, prefix)

    hooks = [_save_episode, _summarize] if store else [_summarize]
    env = wrappers.Collect(env, hooks, config.precision)
    return wrappers.RewardObs(env)
Example #7
0
def make_train_env(config, writer, datadir, gui=False):
    """Build the training environment.

    Multi-agent environments reset with mode ``'random_ball'``, described by
    the original note as sampling random starting points close together
    within a ball. Single-agent environments use ``'random'``. Episodes are
    capped at ``config.time_limit_train / config.action_repeat`` steps. Every
    collected result is both saved and summarized under the ``'train'`` tag.
    """
    env = make_base_env(config, gui)
    reset_mode = 'random_ball' if env.n_agents > 1 else 'random'
    env = wrappers.FixedResetMode(env, mode=reset_mode)
    env = wrappers.TimeLimit(env,
                             config.time_limit_train / config.action_repeat)

    def _save(episodes):
        callbacks.save_episodes(datadir, episodes)

    def _summarize(episodes):
        callbacks.summarize_episode(episodes, config, datadir, writer, 'train')

    return wrappers.Collect(env, [_save, _summarize], config.precision)
Example #8
0
def make_gridworld_env(config, writer, prefix, datadir, store, desc=None):
    """Build a wrapped grid-world environment.

    This function was added for grid-world experiments. It is modeled on
    Dreamer's environment interface and mirrors the wrapper stack that the
    ``make_env`` variants in this module use (TimeLimit -> Collect ->
    RewardObs), with one-hot actions.

    Args:
      config: Reads ``task``, ``time_limit``, ``action_repeat`` and
        ``precision``.
      writer: Forwarded to ``summarize_episode``.
      prefix: Forwarded to ``summarize_episode``.
      datadir: Passed to ``tools.save_episodes`` and ``summarize_episode``.
      store: If true, every finished episode is saved via
        ``tools.save_episodes``.
      desc: Forwarded unchanged to ``wrappers.GridWorld``.

    Returns:
      The fully wrapped environment.

    Raises:
      NotImplementedError: If the suite (the part of ``config.task`` before
        the first underscore, or the whole string if there is none) is not
        ``'gridworld'``.
    """
    # Only the suite prefix is used, so a '_<task>' suffix is not required:
    # a bare 'gridworld' is accepted instead of failing to unpack.
    suite = config.task.split('_', 1)[0]
    if suite != 'gridworld':
        raise NotImplementedError(suite)
    env = wrappers.GridWorld(desc)
    env = wrappers.OneHotAction(env)
    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)
    callbacks = []
    if store:
        callbacks.append(lambda ep: tools.save_episodes(datadir, [ep]))
    callbacks.append(
        lambda ep: summarize_episode(ep, config, datadir, writer, prefix))
    env = wrappers.Collect(env, callbacks, config.precision)
    env = wrappers.RewardObs(env)
    return env
Example #9
0
def wrap_wrt_track(env,
                   action_repeat,
                   outdir,
                   writer,
                   track,
                   checkpoint_id,
                   save_trajectories=False):
    """Wrap an evaluation environment for one track.

    Adds occupancy-map observations and a non-following renderer whose
    videos go to ``save_eval_videos`` under ``outdir / 'videos'``. It then
    collects episodes, optionally saving trajectories, and always summarizes
    them with ``summarize_eval_episode`` tagged by the track name.
    """
    env = wrappers.OccupancyMapObs(env)

    def _save_videos(videos):
        save_eval_videos(videos, outdir / 'videos', action_repeat, track,
                         checkpoint_id)

    env = wrappers.Render(env, [_save_videos], follow_view=False)

    def _save_trajectory(episodes):
        save_trajectory(episodes, outdir, action_repeat, track, checkpoint_id)

    def _summarize(episodes):
        summarize_eval_episode(episodes, outdir, writer, f'{track}',
                               action_repeat)

    episode_hooks = ([_save_trajectory, _summarize]
                     if save_trajectories else [_summarize])
    return wrappers.Collect(env, episode_hooks)
Example #10
0
def make_env(config, writer, prefix, datadir, store):
    """Create the wrapped ``dmc``, ``atari`` or ``football`` environment.

    ``config.task`` must be ``<suite>_<task>``. The football suite always
    builds the ``academy_empty_goal_close`` scenario with pixel observations,
    rendering and dump/video writing into ``./football/empty_goal_close2``.
    Every environment then gets a time limit, episode callbacks (optional
    storage plus a summary) and reward observations.

    Raises:
      ValueError: If ``config.task`` has no underscore.
      NotImplementedError: For an unknown suite.
    """
    suite, task = config.task.split('_', 1)
    if suite not in ('dmc', 'atari', 'football'):
        raise NotImplementedError(suite)
    if suite == 'dmc':
        control = wrappers.ActionRepeat(wrappers.DeepMindControl(task),
                                        config.action_repeat)
        env = wrappers.NormalizeActions(control)
    elif suite == 'atari':
        atari = wrappers.Atari(task,
                               config.action_repeat, (64, 64),
                               grayscale=False,
                               life_done=True,
                               sticky_actions=True)
        env = wrappers.OneHotAction(atari)
    else:
        raw_football = football_env.create_environment(
            representation='pixels',
            env_name='academy_empty_goal_close',
            stacked=False,
            logdir='./football/empty_goal_close2',
            write_goal_dumps=True,
            write_full_episode_dumps=True,
            render=True,
            write_video=True)
        env = wrappers.OneHotAction(wrappers.Football(raw_football))
    env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat)

    def _store(ep):
        tools.save_episodes(datadir, [ep])

    def _report(ep):
        summarize_episode(ep, config, datadir, writer, prefix)

    hooks = [_store, _report] if store else [_report]
    env = wrappers.Collect(env, hooks, config.precision)
    return wrappers.RewardObs(env)