Exemplo n.º 1
0
    def _get_viewer(self, mode):
        """Return the viewer stored for ``mode``, building it on first use.

        When ``self.viewer[mode]`` is ``None`` a ``DmControlViewer`` is
        constructed from ``self.pixels.shape[1]``, ``self.pixels.shape[0]``
        and the ``'depth'`` entry of the mode's render kwargs, then stored
        back into ``self.viewer`` so later calls reuse it.
        """
        existing = self.viewer[mode]
        if existing is not None:
            return existing

        frame_shape = self.pixels.shape
        depth_opt = self.render_mode_list[mode]['render_kwargs']['depth']
        self.viewer[mode] = DmControlViewer(frame_shape[1], frame_shape[0],
                                            depth_opt)
        return self.viewer[mode]


class DmControlInterface(EnvInterface):
    """Environment interface for ids of the form ``dm.<domain>.<task>``."""

    def env_trans_fn(self, env, set_eval):
        """Wrap ``env`` in ``DmControlWrapper`` using the configured time limit.

        ``set_eval`` is accepted but not used by this implementation.
        """
        limit = self.args.time_limit
        return DmControlWrapper(env, limit)

    def create_from_id(self, env_id):
        """Load the suite task named by ``env_id``.

        ``env_id`` is split on ``'.'`` into exactly three parts; the first is
        discarded and the other two are the domain and task names. A
        ``time_limit`` task kwarg is forwarded only when
        ``self.args.time_limit`` is not ``None``.

        A ``NameError`` raised while loading is reported as the suite not
        being installed (presumably ``suite`` is imported conditionally
        elsewhere) and then re-raised.
        """
        # Must be in the format dm.domain.task
        _prefix, domain, task = env_id.split('.')
        try:
            limit = self.args.time_limit
            task_kwargs = None if limit is None else {'time_limit': limit}
            return suite.load(domain_name=domain,
                              task_name=task,
                              task_kwargs=task_kwargs)
        except NameError as err:
            print('DeepMind Control Suite is not installed')
            raise err


# Raw string: "\." is a regex escape, not a valid Python string escape, and
# newer interpreters warn about it in a plain string literal.
register_env_interface(r"^dm\.", DmControlInterface)
Exemplo n.º 2
0
    def step(self, a):
        """Forward action ``a`` to the parent ``step`` and return its result.

        The parent's four return values are unpacked and returned unchanged
        as an ``(obs, reward, done, info)`` tuple.
        """
        observation, rew, finished, extras = super().step(a)
        return observation, rew, finished, extras


class PokerInterface(EnvInterface):
    """Environment interface that builds a hold'em table wrapped in ``PokerAdapter``."""

    def get_add_args(self, parser):
        """Register the ``--poker-n-players`` option on ``parser``."""
        parser.add_argument('--poker-n-players', default=2, type=int, help="""
                Including the agent.
                """)

    def create_from_id(self, env_id):
        """Build, populate and reset a table, then wrap it in ``PokerAdapter``.

        ``env_id`` is not used. The table is created with initial stacks of
        500 and receives ``poker_n_players - 1`` default players followed by
        one player whose ``autoplay`` flag is turned off (presumably the seat
        driven by the agent -- confirm against ``Player``).
        """
        table = HoldemTable(initial_stacks=500)

        # Every seat except the last keeps the default Player settings.
        for _ in range(self.args.poker_n_players - 1):
            table.add_player(Player())

        manual_seat = Player()
        manual_seat.autoplay = False
        table.add_player(manual_seat)
        table.reset()

        return PokerAdapter(table)

# Associate the "Poker-v0" id with PokerInterface. How the id string is
# matched (exact vs. pattern) is decided by register_env_interface.
register_env_interface("Poker-v0", PokerInterface)
Exemplo n.º 3
0
            reward = -np.sum(np.square(self.state - self.goal))
        obs = {
            'observation': np.copy(self.state),
            'achieved_goal': np.copy(self.state),
            'desired_goal': np.copy(self.goal),
        }

        if self.n_steps >= self.n:
            done = True

        info = {}
        if done:
            info['ep_success'] = float(np.array_equal(self.state, self.goal))
        return obs, reward, done, info

    def render(self):
        """Print the current state on a single line without a trailing newline.

        The output starts with a carriage return and ends with ten spaces
        instead of a newline, so repeated calls write over the same line on
        terminals that honor carriage returns.
        """
        state_text = np.array_str(self.state)
        trailing_pad = ' ' * 10
        print("\rstate :", state_text, end=trailing_pad)


class BitFlipInterface(EnvInterface):
    """Environment interface that builds ``BitFlipEnv`` from command-line options."""

    def get_add_args(self, parser):
        """Register ``--bit-flip-n`` and ``--bit-flip-reward`` on ``parser``."""
        add_opt = parser.add_argument
        add_opt('--bit-flip-n', default=5, type=int)
        add_opt('--bit-flip-reward', default='sparse', type=str)

    def create_from_id(self, env_id):
        """Return a ``BitFlipEnv`` built from the parsed options.

        ``env_id`` is not used.
        """
        cfg = self.args
        return BitFlipEnv(cfg.bit_flip_n, cfg.bit_flip_reward)


# Register under the part of the id before the first '-', i.e. without the
# version suffix, so that any version is matched.
register_env_interface(BIT_FLIP_ID.partition('-')[0], BitFlipInterface)
Exemplo n.º 4
0
class MinigridInterface(EnvInterface):
    """Environment interface that creates MiniGrid environments via ``gym.make``."""

    def create_from_id(self, env_id):
        """Create ``env_id`` and wrap it according to the ``--gw-*`` options.

        Args:
            env_id: Id passed unchanged to ``gym.make``.

        Returns:
            The environment wrapped in ``FlatGrid`` (mode ``'flat'``) or in
            ``FullFlatGrid`` over ``FullyObsWrapper`` (mode ``'img'``), and
            additionally in ``DirectionObsWrapper`` when ``--gw-goal-info``
            is set.

        Raises:
            ValueError: If ``--gw-mode`` is neither ``'flat'`` nor ``'img'``.
        """
        env = gym.make(env_id)
        mode = self.args.gw_mode
        if mode == 'flat':
            env = FlatGrid(env, self.args.gw_card_dirs)
        elif mode == 'img':
            env = FullFlatGrid(FullyObsWrapper(env), self.args.gw_card_dirs)
        else:
            # Name the offending value so a mistyped option is easy to spot.
            raise ValueError(
                "Unknown --gw-mode %r; expected 'flat' or 'img'" % (mode,))

        if self.args.gw_goal_info:
            env = DirectionObsWrapper(env)
        return env

    def get_add_args(self, parser):
        """Register the grid-world (``--gw-*``) options on ``parser``."""
        parser.add_argument('--gw-mode',
                            type=str,
                            default='flat',
                            help="""
                Options are: [flat,img]
        """)
        parser.add_argument('--gw-card-dirs',
                            action='store_true',
                            default=False)
        parser.add_argument('--gw-goal-info',
                            action='store_true',
                            default=False)


# The pattern is anchored with '^', so presumably every id that starts with
# "MiniGrid" is routed here -- confirm against register_env_interface.
register_env_interface("^MiniGrid", MinigridInterface)
Exemplo n.º 5
0
import gym
from rlf.envs.env_interface import EnvInterface, register_env_interface
# Environment id string of the form '<name>-<version>'.
BIT_FLIP_ID = 'BitFlip-v0'


class BlackJackWrapper(gym.ObservationWrapper):
    """Observation wrapper that casts the third observation entry to ``int``."""

    def __init__(self, env):
        """Wrap ``env`` and declare a 3-element ``Box`` space bounded by 0 and 100."""
        super().__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=100, shape=(3,))

    def observation(self, observation):
        """Return a list of the first two entries followed by ``int`` of the third."""
        converted = list(observation[:2])
        converted.append(int(observation[2]))
        return converted


class BlackJackInterface(EnvInterface):
    """Environment interface that wraps the parent-created env in ``BlackJackWrapper``."""

    def create_from_id(self, env_id):
        """Create the environment through the parent class, then wrap it."""
        return BlackJackWrapper(super().create_from_id(env_id))


# Registered without a version suffix so that any Blackjack version is
# matched (the matching itself is done by register_env_interface).
register_env_interface('Blackjack', BlackJackInterface)
Exemplo n.º 6
0
            env.action_spec = env.env.action_spec

        # Only render to camera when using as an evaluation environment.
        env = GymWrapper(env)
        env.reward_range = None
        env.metadata = None
        env.spec = None

        return RoboSuiteWrapper(env)

    def get_add_args(self, parser):
        """Register the robosuite (``--rs-*``) command-line options on ``parser``."""
        add_opt = parser.add_argument
        add_opt('--rs-reward-shaping', default=True, type=str2bool)
        add_opt('--rs-demo-path')
        add_opt('--rs-demo-wrapper', default=False, action='store_true')

    def create_from_id(self, env_id, set_eval=False):
        """Create the robosuite environment named by ``env_id``.

        Args:
            env_id: Id with exactly one ``"."``; the part after the dot is the
                task name passed to ``robosuite.make``.
            set_eval: When true, the offscreen renderer and camera
                observations are enabled. Defaults to ``False`` so existing
                single-argument calls keep working.

        Returns:
            The environment returned by ``robosuite.make``.

        Raises:
            ValueError: If ``env_id`` does not split into exactly two parts.
        """
        # NOTE(review): ``set_eval`` used to be read here without being a
        # parameter or local of this method, which raises NameError unless a
        # module-level ``set_eval`` exists. It is now an explicit keyword
        # argument -- confirm that False is the intended default.
        _, task = env_id.split(".")
        env = robosuite.make(task,
                             has_offscreen_renderer=set_eval,
                             has_renderer=False,
                             use_object_obs=True,
                             reward_shaping=self.args.rs_reward_shaping,
                             use_camera_obs=set_eval)
        return env


# Raw string: "\." is a regex escape, not a valid Python string escape, and
# newer interpreters warn about it in a plain string literal.
register_env_interface(r"^rs\.", RoboSuiteControlInterface)
Exemplo n.º 7
0
    def _get_cur_state(self, obs):
        """Return ``(grid, agent_pos, agent_dir)`` read from ``self.env``.

        ``obs`` is accepted but not used.
        """
        wrapped = self.env
        return (wrapped.grid, wrapped.agent_pos, wrapped.agent_dir)


class DoublePlaybackEnvInterface(EnvInterfaceWrapper):
    """Interface that wraps parent-created environments in ``MinigridPlaybackWrapper``.

    ``MinigridInterface`` is handed to ``EnvInterfaceWrapper`` at construction
    (presumably as the interface being wrapped -- confirm against the base
    class).
    """

    def __init__(self, args):
        """Initialize the base wrapper with ``args`` and ``MinigridInterface``."""
        # __init__ returns None, so its result is not bound to a name.
        super().__init__(args, MinigridInterface)

    def create_from_id(self, env_id):
        """Create the environment via the parent class and add the playback wrapper."""
        env = super().create_from_id(env_id)
        return MinigridPlaybackWrapper(env)


# The pattern is anchored with '^', so presumably every id that starts with
# "MiniGrid" is routed here -- confirm against register_env_interface.
register_env_interface("^MiniGrid", DoublePlaybackEnvInterface)

#class DoublePlayUpdater(NestedAlgo):
#    def __init__(self, updater1, updater2):
#        super().__init__([updater1, updater2], 0)
#
#    def get_storage_buffer(self, policy, envs, args):
#        return SwitchingNestedStorage(
#                self.modules[0].get_storage_buffer(policy, envs, args),
#                self.modules[1].get_storage_buffer(policy, envs, args))
#
#
#class SwitchingNestedStorage(NestedStorage):
#    def __init__(self, main_storage, back_storage):
#        super().__init__({
#            'main': main_storage,