示例#1
0
文件: mt50.py 项目: jesbu1/spinningup
    def __init__(self, env_type='train', sample_all=False):
        """Build the 50-task benchmark from every split of the hard-mode tables.

        Tasks from all splits of ``HARD_MODE_CLS_DICT`` are merged into a
        single flat mapping (exactly 50 tasks are expected). Each task is
        then given a single fixed goal: a copy of its env's current ``goal``.

        Args:
            env_type: must be ``'train'`` or ``'test'``. It is validated but
                does not affect which tasks are loaded here.
            sample_all: forwarded unchanged to the parent constructor.
        """
        assert env_type == 'train' or env_type == 'test'
        Serializable.quick_init(self, locals())

        # Flatten {split: {task: cls}} into {task: cls}, with the matching
        # constructor args/kwargs looked up from the same split.
        cls_dict = {}
        args_kwargs = {}
        for split, split_classes in HARD_MODE_CLS_DICT.items():
            for task_name, task_cls in split_classes.items():
                cls_dict[task_name] = task_cls
                args_kwargs[task_name] = HARD_MODE_ARGS_KWARGS[split][task_name]
        assert len(cls_dict) == 50

        super().__init__(
            sample_all=sample_all,
            obs_type='with_goal_id',
            sample_goals=False,
            task_args_kwargs=args_kwargs,
            task_env_cls_dict=cls_dict,
        )

        # One goal per task: a copy of each constructed env's current goal.
        goals_dict = {}
        for task_name, task_env in zip(self._task_names, self._task_envs):
            goals_dict[task_name] = [task_env.goal.copy()]

        self.discretize_goal_space(goals_dict)
        assert self._fully_discretized
示例#2
0
文件: ml1.py 项目: jesbu1/spinningup
    def __init__(self,
                 task_name,
                 env_type='train',
                 n_goals=50,
                 sample_all=False):
        """Build a single-task env whose goal space is ``n_goals`` sampled goals.

        The task is looked up first in the ``'train'`` split of the hard-mode
        tables and then in the ``'test'`` split.

        Args:
            task_name: key of the task in ``HARD_MODE_CLS_DICT['train']`` or
                ``HARD_MODE_CLS_DICT['test']``.
            env_type: must be ``'train'`` or ``'test'``. It is validated but
                not used to choose the split; ``task_name`` alone decides that.
            n_goals: number of goals passed to ``sample_goals_``; they become
                this task's discrete goal set.
            sample_all: forwarded unchanged to the parent constructor.

        Raises:
            NotImplementedError: if ``task_name`` is in neither split.
        """
        assert env_type == 'train' or env_type == 'test'
        Serializable.quick_init(self, locals())

        # Search train first, then test (same precedence as an if/elif).
        for split in ('train', 'test'):
            if task_name in HARD_MODE_CLS_DICT[split]:
                cls_dict = {task_name: HARD_MODE_CLS_DICT[split][task_name]}
                args_kwargs = {
                    task_name: HARD_MODE_ARGS_KWARGS[split][task_name]
                }
                break
        else:
            # Keep the original exception type so existing handlers still
            # match, but say which name was rejected.
            raise NotImplementedError(
                'Unknown task {!r}: not found in the train or test split of '
                'HARD_MODE_CLS_DICT'.format(task_name))

        super().__init__(task_env_cls_dict=cls_dict,
                         task_args_kwargs=args_kwargs,
                         sample_goals=True,
                         obs_type='plain',
                         sample_all=sample_all)

        # Draw a fixed pool of goals from the active env (the only task here)
        # and register it as this task's discrete goal space.
        goals = self.active_env.sample_goals_(n_goals)
        self.discretize_goal_space({task_name: goals})
示例#3
0
 def __init__(self,
              task_env_cls=None,
              task_args=None,
              task_kwargs=None,):
     """Create one ``task_env_cls`` instance per (args, kwargs) pair.

     Args:
         task_env_cls: callable (env class) used to build every task env.
         task_args: iterable of positional-argument sequences, one per task.
         task_kwargs: iterable of keyword-argument mappings, one per task;
             must yield the same number of entries as ``task_args``.

     Raises:
         TypeError: if ``task_args`` or ``task_kwargs`` is None.
         ValueError: if ``task_args`` and ``task_kwargs`` differ in length.
             A plain ``zip`` would silently drop the extra entries.
     """
     Serializable.quick_init(self, locals())
     if task_args is None or task_kwargs is None:
         raise TypeError(
             'task_args and task_kwargs must both be provided '
             '(got task_args={!r}, task_kwargs={!r})'.format(
                 task_args, task_kwargs))
     # Materialize once so iterators/generators are still accepted and their
     # lengths can be compared.
     args_list = list(task_args)
     kwargs_list = list(task_kwargs)
     if len(args_list) != len(kwargs_list):
         raise ValueError(
             'task_args has {} entries but task_kwargs has {}'.format(
                 len(args_list), len(kwargs_list)))
     self._task_envs = [
         task_env_cls(*t_args, **t_kwargs)
         for t_args, t_kwargs in zip(args_list, kwargs_list)
     ]
     self._active_task = None
示例#4
0
    def __init__(self, env_type='train', sample_all=False):
        """Build the debug-mode multi-task env for one split.

        The task classes and constructor args come from
        ``DEBUG_MODE_CLS_DICT[env_type]`` and
        ``DEBUG_MODE_ARGS_KWARGS[env_type]``. Goals are sampled
        (``sample_goals=True``) and plain observations are requested.

        Args:
            env_type: split to load, either ``'train'`` or ``'test'``.
            sample_all: forwarded unchanged to the parent constructor.
        """
        assert env_type == 'train' or env_type == 'test'
        Serializable.quick_init(self, locals())

        super().__init__(
            task_env_cls_dict=DEBUG_MODE_CLS_DICT[env_type],
            task_args_kwargs=DEBUG_MODE_ARGS_KWARGS[env_type],
            sample_goals=True,
            obs_type='plain',
            sample_all=sample_all,
        )
示例#5
0
    def __init__(
        self,
        task_env_cls_dict,
        task_args_kwargs,
        sample_all=True,
        sample_goals=False,
        obs_type='plain',
        repeat_times=1,
    ):
        """Initialize the parent multi-task env and add two extra attributes.

        Args:
            task_env_cls_dict: forwarded to the parent constructor.
            task_args_kwargs: forwarded to the parent constructor.
            sample_all: forwarded to the parent constructor.
            sample_goals: forwarded to the parent constructor.
            obs_type: forwarded to the parent constructor.
            repeat_times: stored as ``self.repeat_times``.
        """
        Serializable.quick_init(self, locals())
        # NOTE(review): arguments are forwarded positionally, so this relies
        # on the parent signature starting with (task_env_cls_dict,
        # task_args_kwargs, sample_all, sample_goals, obs_type). Confirm if
        # the parent changes.
        parent_args = (task_env_cls_dict, task_args_kwargs, sample_all,
                       sample_goals, obs_type)
        super().__init__(*parent_args)

        self.repeat_times = repeat_times
        self.train_mode = True  # the env starts in train mode
示例#6
0
    def __init__(
        self,
        task_env_cls_dict,
        task_args_kwargs,
        sample_all=True,
        sample_goals=False,
        obs_type='plain',
    ):
        """Instantiate one env per task and set up discrete-goal bookkeeping.

        Args:
            task_env_cls_dict: mapping of task name -> env class. Must contain
                at least one task.
            task_args_kwargs: mapping of task name ->
                ``{'args': ..., 'kwargs': ...}`` used to construct that task's
                env. It must cover exactly the task names of
                ``task_env_cls_dict``.
            sample_all: stored as ``self._sampled_all``.
            sample_goals: stored as ``self._sample_goals``. When falsy, the
                env starts out marked as fully discretized.
            obs_type: stored as ``self._obs_type``. Each constructed task env
                must itself report ``obs_type == 'plain'``.
        """
        Serializable.quick_init(self, locals())
        n_tasks = len(task_env_cls_dict.keys())
        assert n_tasks == len(task_args_kwargs.keys())
        assert n_tasks >= 1
        assert all(name in task_args_kwargs
                   for name in task_env_cls_dict.keys())

        self._task_envs = []   # constructed envs, in task_env_cls_dict order
        self._task_names = []  # task names, parallel to self._task_envs
        self._sampled_all = sample_all
        self._sample_goals = sample_goals
        self._obs_type = obs_type

        for name, env_cls in task_env_cls_dict.items():
            spec = task_args_kwargs[name]
            env = env_cls(*spec['args'], **spec['kwargs'])
            # Task envs must emit plain observations, because this wrapper
            # applies all observation augmentation itself.
            assert env.obs_type == 'plain'
            self._task_envs.append(env)
            self._task_names.append(name)

        # Per the original design: a task name present as a key here is
        # treated as using a discrete goal space. The wrapper then reports
        # discrete_goal_space as True, updates the goal space, and samples
        # goals from the discrete set.
        self._discrete_goals = {}
        # task name -> position of that task's env in self._task_envs
        self._env_discrete_index = dict(
            (name, idx) for idx, name in enumerate(self._task_names))
        self._fully_discretized = not sample_goals
        self._n_discrete_goals = len(task_env_cls_dict.keys())
        self._active_task = 0
        self._check_env_list()
示例#7
0
文件: mt10.py 项目: jesbu1/spinningup
    def __init__(self, env_type='train', sample_all=False):
        """Build the 10-task benchmark from the easy-mode tables.

        Every task in ``EASY_MODE_CLS_DICT`` is loaded and given a single
        fixed goal: a copy of its env's current ``goal``.

        Args:
            env_type: must be ``'train'`` or ``'test'``. It is validated but
                does not affect which tasks are loaded here.
            sample_all: forwarded unchanged to the parent constructor.
        """
        assert env_type == 'train' or env_type == 'test'
        Serializable.quick_init(self, locals())

        super().__init__(
            sample_all=sample_all,
            obs_type='with_goal_id',
            sample_goals=False,
            task_args_kwargs=EASY_MODE_ARGS_KWARGS,
            task_env_cls_dict=EASY_MODE_CLS_DICT,
        )

        # One goal per task: a copy of each constructed env's current goal.
        goals_dict = dict(
            (name, [env.goal.copy()])
            for name, env in zip(self._task_names, self._task_envs))

        self.discretize_goal_space(goals_dict)
        assert self._fully_discretized
示例#8
0
 def __init__(
         self,
         env,
         obs_means=None,
         obs_stds=None,
         obs_to_normalize_keys=None,
 ):
     """Wrap a dict-observation env and prepare per-key normalization stats.

     The wrapper exposes a Box action space spanning [-1, 1] in every
     dimension of the wrapped env's action shape.

     Args:
         env: the environment to wrap. When normalization is enabled, its
             ``observation_space`` must be indexable by each key in
             ``obs_to_normalize_keys`` and expose ``.low``.
         obs_means: optional mapping of observation key -> mean. If only
             ``obs_stds`` is given, zeros are used for every key.
         obs_stds: optional mapping of observation key -> std. If only
             ``obs_means`` is given, ones are used for every key.
         obs_to_normalize_keys: observation keys to normalize. Defaults to
             ``['observation']``. A fresh list is created per call, so it is
             never shared between instances.

     Normalization is enabled unless both ``obs_means`` and ``obs_stds`` are
     None. In that case ``self._obs_means``/``self._obs_stds`` stay None.
     """
     # self._wrapped_env needs to be set first because
     # Serializable.quick_init calls getattr on this class, and the
     # implementation of getattr (see below) calls self._wrapped_env.
     # Without setting this first, the call to self._wrapped_env would call
     # getattr again (since it's not set yet) and therefore loop forever.
     # Otherwise, serialization gets delegated to the wrapped_env. Serialize
     # this env separately from the wrapped_env.
     self._wrapped_env = env
     Serializable.quick_init(self, locals())
     ProxyEnv.__init__(self, env)
     if obs_to_normalize_keys is None:
         obs_to_normalize_keys = ['observation']
     # Must be assigned before the loops below, which read it back.
     self.obs_to_normalize_keys = obs_to_normalize_keys
     self._should_normalize = not (obs_means is None and obs_stds is None)
     if self._should_normalize:
         if obs_means is None:
             obs_means = {
                 key: np.zeros_like(env.observation_space[key].low)
                 for key in self.obs_to_normalize_keys
             }
         else:
             # Read from the caller's mapping. The comprehension is fully
             # evaluated before the name is rebound.
             obs_means = {
                 key: np.array(obs_means[key])
                 for key in self.obs_to_normalize_keys
             }
         if obs_stds is None:
             # Unit std is the neutral default for a scale. A zero default
             # would make any later division by std blow up.
             obs_stds = {
                 key: np.ones_like(env.observation_space[key].low)
                 for key in self.obs_to_normalize_keys
             }
         else:
             obs_stds = {
                 key: np.array(obs_stds[key])
                 for key in self.obs_to_normalize_keys
             }
     self._obs_means = obs_means
     self._obs_stds = obs_stds
     ub = np.ones(self._wrapped_env.action_space.shape)
     self.action_space = Box(-1 * ub, ub)
示例#9
0
    def __init__(self,
                 env_type='train',
                 n_tasks=2,
                 randomize_tasks=True,
                 sample_all=False):
        """Build a small hand-picked meta-learning split of hard-mode tasks.

        Both splits take their classes and constructor args from the
        ``'train'`` entries of ``HARD_MODE_CLS_DICT`` /
        ``HARD_MODE_ARGS_KWARGS``.

        Args:
            env_type: ``'train'`` selects push / button-press / sweep-into /
                plate-slide; ``'test'`` selects coffee-button / drawer-close.
            n_tasks: passed to ``self.sample_tasks``; the result is stored in
                ``self.tasks``.
            randomize_tasks: accepted for interface compatibility; not used
                by this constructor.
            sample_all: forwarded unchanged to the parent constructor.
        """
        # NOTE(review): this flag is set *before* quick_init. If
        # Serializable.quick_init returns early when the flag is already set
        # (as the rllab/multiworld implementation does), then neither this
        # call nor the parent's records any constructor args. Confirm that
        # skipping serialization here is intentional.
        self._serializable_initialized = True
        assert env_type == 'train' or env_type == 'test'
        Serializable.quick_init(self, locals())

        train_tasks = ('push-v1', 'button-press-v1', 'sweep-into-v1',
                       'plate-slide-v1')
        test_tasks = ('coffee-button-v1', 'drawer-close-v1')
        task_names = train_tasks if env_type == 'train' else test_tasks
        # Only the selected split is looked up. Both splits deliberately draw
        # from the benchmark's 'train' pool.
        cls_dict = dict(
            (name, HARD_MODE_CLS_DICT['train'][name]) for name in task_names)
        args_kwargs = dict(
            (name, HARD_MODE_ARGS_KWARGS['train'][name]) for name in task_names)

        super().__init__(task_env_cls_dict=cls_dict,
                         task_args_kwargs=args_kwargs,
                         sample_goals=True,
                         obs_type='plain',
                         sample_all=sample_all)

        self.tasks = self.sample_tasks(n_tasks)
        self.reset_task(0)