Example #1
0
def rollout(env,
            policy,
            path_length,
            callback=None,
            render_mode=None,
            break_on_terminal=True):
    """Roll out `policy` in `env` for at most `path_length` steps.

    Args:
        env: Environment exposing `observation_space` and `action_space`.
        policy: Policy consumed by the sampler; `policy.reset()` is called
            whenever a terminal step is sampled.
        path_length: Maximum number of environment steps to take.
        callback: Optional callable invoked with each sampled observation.
        render_mode: `'rgb_array'` to collect frames from the mujoco-py
            simulator; any other non-None mode raises NotImplementedError.
        break_on_terminal: If True, stop at the first terminal step.

    Returns:
        A batch dict from the replay pool with `'infos'` appended, plus
        `'images'` (stacked frames) when `render_mode == 'rgb_array'`.
    """
    observation_space = env.observation_space
    action_space = env.action_space

    pool = replay_pools.SimpleReplayPool(observation_space,
                                         action_space,
                                         max_size=path_length)
    sampler = simple_sampler.SimpleSampler(max_path_length=path_length,
                                           min_pool_size=None,
                                           batch_size=None)

    sampler.initialize(env, policy, pool)

    images = []
    infos = []

    # Count steps explicitly so the pool-size invariant below also holds
    # when path_length == 0 (the old `t + 1` form asserted 0 == 1).
    num_steps = 0
    for t in range(path_length):
        observation, reward, terminal, info = sampler.sample()
        num_steps = t + 1
        infos.append(info)

        if callback is not None:
            callback(observation)

        if render_mode is not None:
            if render_mode == 'rgb_array':
                # NOTE: direct `sim.render` only works for mujoco-py envs;
                # fall back to a 200px frame when the env declares no imsize.
                imsize = getattr(env.unwrapped, 'imsize', 200)
                image = env.unwrapped.sim.render(imsize, imsize)
                images.append(image)
            else:
                raise NotImplementedError(
                    "Unsupported render_mode: {!r}".format(render_mode))

        if terminal:
            policy.reset()
            if break_on_terminal: break

    # Internal invariant: the pool recorded exactly one entry per step.
    assert pool._size == num_steps

    path = pool.batch_by_indices(np.arange(pool._size),
                                 observation_keys=getattr(
                                     env, 'observation_keys', None))
    path['infos'] = infos

    if render_mode == 'rgb_array':
        path['images'] = np.stack(images, axis=0)

    return path
Example #2
0
def rollout(environment,
            policy,
            path_length,
            sampler_class=simple_sampler.SimpleSampler,
            callback=None,
            render_kwargs=None,
            break_on_terminal=True):
    """Roll out `policy` in `environment` for at most `path_length` steps.

    Args:
        environment: Environment the sampler steps through.
        policy: Policy consumed by the sampler; `policy.reset()` is called
            whenever a terminal step is sampled.
        path_length: Maximum number of environment steps to take.
        sampler_class: Sampler type constructed around the environment,
            policy, and replay pool.
        callback: Optional callable invoked with each sampled observation.
        render_kwargs: Dict of kwargs for `environment.render`; its
            `'mode'` key selects `'rgb_array'` (frames collected) or
            `'human'` (on-screen). Any other mode disables rendering.
        break_on_terminal: If True, stop at the first terminal step.

    Returns:
        A batch dict from the replay pool with `'infos'` (a defaultdict of
        per-key lists) appended, plus `'images'` when mode is
        `'rgb_array'`.
    """
    pool = replay_pools.SimpleReplayPool(environment, max_size=path_length)
    sampler = sampler_class(
        environment=environment,
        policy=policy,
        pool=pool,
        max_path_length=path_length)

    # Merge user kwargs over the mode-specific defaults; unknown modes
    # disable rendering entirely.
    render_mode = (render_kwargs or {}).get('mode', None)
    if render_mode == 'rgb_array':
        render_kwargs = {
            **DEFAULT_PIXEL_RENDER_KWARGS,
            **render_kwargs
        }
    elif render_mode == 'human':
        render_kwargs = {
            **DEFAULT_HUMAN_RENDER_KWARGS,
            **render_kwargs
        }
    else:
        render_kwargs = None

    images = []
    infos = defaultdict(list)

    # Count steps explicitly so the pool-size invariant below also holds
    # when path_length == 0 (the old `t + 1` form asserted 0 == 1).
    num_steps = 0
    for t in range(path_length):
        observation, reward, terminal, info = sampler.sample()
        num_steps = t + 1
        for key, value in info.items():
            infos[key].append(value)

        if callback is not None:
            callback(observation)

        if render_kwargs:
            image = environment.render(**render_kwargs)
            images.append(image)

        if terminal:
            policy.reset()
            if break_on_terminal: break

    # Internal invariant: the pool recorded exactly one entry per step.
    assert pool._size == num_steps

    path = pool.batch_by_indices(np.arange(pool._size))
    path['infos'] = infos

    if render_mode == 'rgb_array':
        path['images'] = np.stack(images, axis=0)

    return path
Example #3
0
def my_rollout(env,
            policy,
            path_length,
            callback=None,
            render_mode=None,
            break_on_terminal=True):
    """Roll out `policy` in `env`, recording raw actions and observations.

    Behaves like the standard rollout but additionally stores each raw
    action and observation returned by `sampler.my_sample()` under the
    `'actions'` and `'states'` keys of the returned path dict.
    """
    pool = replay_pools.SimpleReplayPool(
        env.observation_space, env.action_space, max_size=path_length)
    sampler = simple_sampler.SimpleSampler(
        max_path_length=path_length,
        min_pool_size=None,
        batch_size=None)
    sampler.initialize(env, policy, pool)

    frames, info_list = [], []
    action_list, state_list = [], []

    t = 0
    for t in range(path_length):
        observation, reward, terminal, info, action = sampler.my_sample()
        info_list.append(info)
        action_list.append(action)
        state_list.append(observation)

        if callback is not None:
            callback(observation)

        if render_mode == 'rgb_array':
            frames.append(env.render(mode=render_mode))
        elif render_mode is not None:
            env.render()

        if terminal:
            policy.reset()
            if break_on_terminal:
                break

    assert pool._size == t + 1

    path = pool.batch_by_indices(
        np.arange(pool._size),
        observation_keys=getattr(env, 'observation_keys', None))
    path['infos'] = info_list
    path['actions'] = action_list
    path['states'] = state_list

    if render_mode == 'rgb_array':
        path['images'] = np.stack(frames, axis=0)

    return path
Example #4
0
def rollout(env,
            policy,
            path_length,
            callback=None,
            render_mode=None,
            break_on_terminal=True):
    """Sample one trajectory of at most `path_length` steps from `env`.

    Rendering is performed when `render_mode` is neither None nor the
    sentinel string 'No'; `'rgb_array'` frames are collected and stacked
    into the returned path under `'images'`.
    """
    pool = replay_pools.SimpleReplayPool(
        env.observation_space, env.action_space, max_size=path_length)
    sampler = simple_sampler.SimpleSampler(
        max_path_length=path_length,
        min_pool_size=None,
        batch_size=None)

    sampler.initialize(env, policy, pool)

    frames = []
    step_infos = []

    t = 0
    for t in range(path_length):
        observation, reward, terminal, info = sampler.sample()
        step_infos.append(info)

        if callback is not None:
            callback(observation)

        # Both None and the sentinel string 'No' disable rendering.
        should_render = render_mode is not None and render_mode != 'No'
        if should_render:
            if render_mode == 'rgb_array':
                frames.append(env.render(mode=render_mode))
            else:
                env.render()

        if terminal:
            policy.reset()
            if break_on_terminal:
                break

    assert pool._size == t + 1

    path = pool.batch_by_indices(
        np.arange(pool._size),
        observation_keys=getattr(env, 'observation_keys', None))
    path['infos'] = step_infos

    if render_mode == 'rgb_array':
        path['images'] = np.stack(frames, axis=0)

    return path
Example #5
0
def rollout(env,
            policy,
            path_length,
            sampler_class=simple_sampler.SimpleSampler,
            sampler_kwargs=None,
            callback=None,
            render_kwargs=None,
            break_on_terminal=True):
    """Sample one trajectory of at most `path_length` steps from `env`.

    `render_kwargs['mode']` selects rendering: `'rgb_array'` collects
    frames (with special handling for observations carrying two stacked
    RGB images under the `'pixels'` key), `'human'` renders on screen,
    and any other mode disables rendering.
    """
    pool = replay_pools.SimpleReplayPool(env, max_size=path_length)
    base_sampler_kwargs = dict(
        max_path_length=path_length,
        min_pool_size=None,
        batch_size=None)
    if sampler_kwargs:
        sampler = sampler_class(**base_sampler_kwargs, **sampler_kwargs)
    else:
        sampler = sampler_class(**base_sampler_kwargs)

    sampler.initialize(env, policy, pool)

    # Merge user kwargs over the mode-specific defaults; unknown modes
    # disable rendering entirely.
    render_mode = (render_kwargs or {}).get('mode', None)
    if render_mode == 'rgb_array':
        render_kwargs = {
            **DEFAULT_PIXEL_RENDER_KWARGS,
            **render_kwargs
        }
    elif render_mode == 'human':
        render_kwargs = {
            **DEFAULT_HUMAN_RENDER_KWARGS,
            **render_kwargs
        }
    else:
        render_kwargs = None

    frames = []
    info_lists = defaultdict(list)
    t = 0
    for t in range(path_length):
        observation, reward, terminal, info = sampler.sample()
        for key, value in info.items():
            info_lists[key].append(value)

        if callback is not None:
            callback(observation)

        if render_kwargs:
            if render_mode == 'rgb_array':
                # NOTE(review): imsize lookup only applies to mujoco-py
                # envs; both values below are retained from the original
                # code and are currently unused.
                imsize = getattr(env.unwrapped, 'imsize', 200)
                imsize_flat = imsize * imsize * 3

                has_stacked_pixels = (
                    'pixels' in observation.keys()
                    and observation['pixels'].shape[-1] == 6)
                if has_stacked_pixels:
                    # Two stacked RGB frames: place them side by side.
                    stacked = observation['pixels']
                    frame = np.concatenate(
                        [stacked[:, :, :3], stacked[:, :, 3:]], axis=1)
                else:
                    frame = env.render(**render_kwargs)
            else:
                frame = env.render(**render_kwargs)
            frames.append(frame)

        if terminal:
            policy.reset()
            if break_on_terminal:
                break

    assert pool._size == t + 1

    path = pool.batch_by_indices(np.arange(pool._size))
    path['infos'] = info_lists

    if render_mode == 'rgb_array':
        path['images'] = np.stack(frames, axis=0)

    return path