Example #1
    def __init__(self, wrapped_env):
        # Convert the wrapped env's Gym spaces once and expose them as an EnvSpec.
        self.__wrapped_env = wrapped_env
        observation_space = wrapped_env.observation_space
        action_space = wrapped_env.action_space
        self.observation_space = convert_gym_space(observation_space)
        self.action_space = convert_gym_space(action_space)
        self.spec = EnvSpec(self.observation_space, self.action_space)
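All of these snippets lean on the same convert_gym_space helper, which maps gym.spaces objects onto the framework's own space classes. A minimal sketch of that mapping, assuming the rllab-style Box, Discrete, and Product space types:

import gym.spaces
from rllab.spaces import Box, Discrete, Product

def convert_gym_space(space):
    # Translate a gym.spaces object into the equivalent rllab space.
    if isinstance(space, gym.spaces.Box):
        return Box(low=space.low, high=space.high)
    elif isinstance(space, gym.spaces.Discrete):
        return Discrete(n=space.n)
    elif isinstance(space, gym.spaces.Tuple):
        # Tuple spaces become a Product of recursively converted members.
        return Product([convert_gym_space(x) for x in space.spaces])
    else:
        raise NotImplementedError(type(space))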
Example #2
def run_task(args, *_):

    #env = TfEnv(normalize(dnc_envs.create_stochastic('pick'))) # Cannot be solved easily by TRPO
    metaworld_env = ML1.get_train_tasks("pick-place-v1")
    tasks = metaworld_env.sample_tasks(1)
    metaworld_env.set_task(tasks[0])
    # Swap in converted spaces so the rllab-style wrappers below accept the env.
    metaworld_env._observation_space = convert_gym_space(metaworld_env.observation_space)
    metaworld_env._action_space = convert_gym_space(metaworld_env.action_space)
    env = TfEnv(normalize(metaworld_env))

    policy = GaussianMLPPolicy(
        name="policy",
        env_spec=env.spec,
        min_std=1e-2,
        hidden_sizes=(150, 100, 50),
    )

    baseline = LinearFeatureBaseline(env_spec=env.spec)

    algo = TRPO(
        env=env,
        policy=policy,
        baseline=baseline,
        #batch_size=50000,
        batch_size=100,
        force_batch_sampler=True,
        max_path_length=50,
        discount=1,
        step_size=0.02,
    )
    
    algo.train()
Example #3
    def __init__(self,
                 env_name,
                 record_video=True,
                 video_schedule=None,
                 log_dir=None,
                 record_log=True,
                 force_reset=False,
                 screen_width=84,
                 screen_height=84):
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log(
                    "Warning: skipping Gym environment monitoring since snapshot_dir not configured."
                )
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        if 'Doom' in env_name:
            from ppaquette_gym_doom.wrappers.action_space import ToDiscrete
            wrapper = ToDiscrete('minimal')
            env = wrapper(env)

        self.env = env
        self.env_id = env.spec.id

        monitor_manager.logger.setLevel(logging.WARNING)

        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self.env = gym.wrappers.Monitor(self.env,
                                            log_dir,
                                            video_callable=video_schedule,
                                            force=True)
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        self._action_space = convert_gym_space(env.action_space)
        self._horizon = env.spec.timestep_limit
        self._log_dir = log_dir
        self._force_reset = force_reset
        self.screen_width = screen_width
        self.screen_height = screen_height
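        # NOTE: this replaces the converted space above; observations here are
        # single-channel screen images with values in [0, 1].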
        self._observation_space = Box(low=0,
                                      high=1,
                                      shape=(screen_width, screen_height, 1))
Example #4
    def __init__(self,
                 env_name,
                 wrappers=(),
                 wrapper_args=(),
                 record_video=True,
                 video_schedule=None,
                 log_dir=None,
                 record_log=True,
                 post_create_env_seed=None,
                 force_reset=False):
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log(
                    "Warning: skipping Gym environment monitoring since snapshot_dir not configured."
                )
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        if post_create_env_seed is not None:
            env.set_env_seed(post_create_env_seed)
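        # Apply any extra wrappers in order; per-wrapper kwargs are used only
        # when exactly one argument dict is supplied per wrapper.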
        for i, wrapper in enumerate(wrappers):
            if wrapper_args and len(wrapper_args) == len(wrappers):
                env = wrapper(env, **wrapper_args[i])
            else:
                env = wrapper(env)
        self.env = env
        self.env_id = env.spec.id

        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            # self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
            self.env = CustomGymMonitorEnv(self.env,
                                           log_dir,
                                           video_callable=video_schedule,
                                           force=True)
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(env.action_space)
        logger.log("action space: {}".format(self._action_space))
        self._horizon = env.spec.tags.get(
            'wrapper_config.TimeLimit.max_episode_steps')
        self._log_dir = log_dir
        self._force_reset = force_reset
Example #5
    def __init__(self,
                 wrapped_gym_env,
                 record_video=True,
                 video_schedule=None,
                 log_dir=None,
                 record_log=True,
                 force_reset=False):

        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log(
                    "Warning: skipping Gym environment monitoring since snapshot_dir not configured."
                )
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        Serializable.quick_init(self, locals())

        env = wrapped_gym_env
        self.env = env
        self.env_id = env.spec.id

        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self.env = gym.wrappers.Monitor(self.env,
                                            log_dir,
                                            video_callable=video_schedule,
                                            force=True)
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(env.action_space)
        logger.log("action space: {}".format(self._action_space))
        self._horizon = env.spec.tags[
            'wrapper_config.TimeLimit.max_episode_steps']
        self._log_dir = log_dir
        self._force_reset = force_reset
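This variant matches Example #4 almost line for line, except that it takes an already-constructed Gym environment (wrapped_gym_env) rather than an env_name, and it indexes env.spec.tags directly instead of using .get(), so it raises KeyError when the episode-limit tag is missing.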
Example #6
def _setup_model(env, new_reward, tf_cfg):
    env_spec = EnvSpec(
        observation_space=to_tf_space(convert_gym_space(env.observation_space)),
        action_space=to_tf_space(convert_gym_space(env.action_space)))
    model_cfg, reward_params = new_reward
    infer_graph = tf.Graph()
    with infer_graph.as_default():
        model_kwargs = dict(model_cfg)
        model_cls = model_kwargs.pop('model')
        irl_model = model_cls(env_spec=env_spec,
                              expert_trajs=None,
                              **model_kwargs)
        if model_cls == AIRLStateOnly:
            reward_var = irl_model.reward
        elif model_cls == AIRLStateAction:
            reward_var = irl_model.energy
        else:
            assert False, "Unsupported model type"
        sess = tf.Session(config=tf_cfg)
        with sess.as_default():
            irl_model.set_params(reward_params)
    return sess, irl_model, reward_var
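A plausible call site for this helper, assuming new_reward was built elsewhere as a (model_cfg, reward_params) pair (the tf_cfg value is just an example):

import tensorflow as tf

tf_cfg = tf.ConfigProto(device_count={'GPU': 0})  # e.g. force CPU-only inference
sess, irl_model, reward_var = _setup_model(env, new_reward, tf_cfg)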
Example #7
    def __init__(self, venv):
        # Convert the vectorized env's shared spaces once at construction time.
        self.venv = venv
        self._observation_space = convert_gym_space(venv.observation_space)
        self._action_space = convert_gym_space(venv.action_space)
Example #8
    def __init__(self, env_name, register_info=None, record_video=True, video_schedule=None, log_dir=None, record_log=True,
                 force_reset=True, screen_width=84, screen_height=84, frame_skip=1, doom_actionspace='Box',
                 conv=True, client_port=10000, transpose_output=False, stack_frames=False, stack_size=4):
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        Serializable.quick_init(self, locals())
        if 'Doom' in env_name:
            import ex2.envs.doom
        if 'Minecraft' in env_name:
            import axe.envs.minecraft

        if register_info:
            try:
                gym.envs.register(**register_info)
            except gym.error.Error:
                traceback.print_exc()

        env = gym.envs.make(env_name)

        if 'Doom' in env_name:
            from ex2.envs.doom.wrappers import SetResolution
            from ex2.envs.doom.wrappers.action_space import ToDiscrete, ToBox
            if doom_actionspace == 'Box':
                wrapper1 = ToBox('minimal')
            else:
                wrapper1 = ToDiscrete('minimal')
            #lock = multiprocessing.Lock()
            #env.configure(lock=lock)
            wrapper2 = SetResolution('160x120')
            env = wrapper2(wrapper1(env))
        if 'Minecraft' in env_name:
            env.init(videoResolution=[screen_width, screen_height], allowContinuousMovement=["move", "turn"],
                     continuous_discrete=False, vision=False,
                     client_pool=[('127.0.0.1', client_port)])

        self.env = env
        self.env_id = env.spec.id
        self.env_name = env_name
        self.frame_skip = frame_skip
        self.stack_frames = stack_frames
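        # Channel dimension: the number of stacked frames when stacking,
        # otherwise 3 (RGB).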
        if stack_frames:
            self.channel_size = stack_size
        else:
            self.channel_size = 3

        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
            self.monitoring = True

        self._action_space = convert_gym_space(env.action_space)
        self._horizon = env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps')
        self._log_dir = log_dir
        self._force_reset = force_reset
        self.screen_width = screen_width
        self.screen_height = screen_height
        self.conv = conv
        self.transpose_output = transpose_output
        if conv:
            if self.transpose_output:
                self._observation_space = Box(low=0, high=1, shape=(self.channel_size, screen_width, screen_height))
                #self._observation_space = Box(low=0, high=1, shape=(3* screen_width* screen_height))
            else:
                self._observation_space = Box(low=0, high=1, shape=(screen_width, screen_height, self.channel_size))
        else:
            self._observation_space = Box(low=0, high=1, shape=(self.channel_size,))
        self.last_info = None
        self.last_obs = []
Example #9
    def __init__(self, env):
        Serializable.quick_init(self, locals())
        self.env = env
        self._observation_space = convert_gym_space(env.observation_space)
        self._action_space = convert_gym_space(env.action_space)
        # The horizon is hard-coded here rather than read from the env's spec.
        self._horizon = 500
Example #10
    def observation_space(self):
        return convert_gym_space(self.conopt_env.observation_space)
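Unlike the constructors above, this accessor converts the space on every call; it is presumably a property on a thin wrapper around conopt_env.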