Example #1
0
    def __init__(self, env_type="train", sample_all=False, task_name=None):
        """Build the MT50 benchmark, optionally restricted to a single task.

        The nested registries ``HARD_MODE_CLS_DICT`` and
        ``HARD_MODE_ARGS_KWARGS`` (shaped ``{group: {task: value}}``) are
        flattened into task-keyed dicts and passed to the parent constructor.
        Afterwards every task gets exactly one goal (a copy of its env's
        current ``goal``), and the goal space is discretized.

        Args:
            env_type: Ignored; kept so existing callers can still pass it.
            sample_all: Passed through unchanged to the parent constructor.
            task_name: When not ``None``, only this task is kept.

        Raises:
            ValueError: If ``task_name`` is not one of the flattened tasks.
        """
        del env_type  # accepted for signature compatibility, never used

        env_classes = {}
        env_args = {}
        for group_key, group_envs in HARD_MODE_CLS_DICT.items():
            for name, env_cls in group_envs.items():
                env_classes[name] = env_cls
                env_args[name] = HARD_MODE_ARGS_KWARGS[group_key][name]
        # Sanity check: the flattened registry must hold exactly 50 tasks.
        assert len(env_classes) == 50

        if task_name is not None:
            if task_name not in env_classes:
                raise ValueError(
                    "{} does not exist in MT50 tasks".format(task_name))
            # Narrow both registries down to the requested task only.
            env_classes = {task_name: env_classes[task_name]}
            env_args = {task_name: env_args[task_name]}

        super().__init__(
            task_env_cls_dict=env_classes,
            task_args_kwargs=env_args,
            sample_goals=False,
            obs_type='with_goal_id',
            sample_all=sample_all,
        )

        # One fixed goal per task: a copy of that env's current goal.
        goals_by_task = {}
        for name, env in zip(self._task_names, self._task_envs):
            goals_by_task[name] = [env.goal.copy()]

        self.discretize_goal_space(goals_by_task)
        assert self._fully_discretized
Example #2
0
File: mt50.py  Project: jesbu1/spinningup
    def __init__(self, env_type='train', sample_all=False):
        """Build the MT50 benchmark over every hard-mode task.

        The nested registries ``HARD_MODE_CLS_DICT`` and
        ``HARD_MODE_ARGS_KWARGS`` (shaped ``{group: {task: value}}``) are
        flattened into task-keyed dicts and passed to the parent constructor.
        Afterwards every task gets exactly one goal (a copy of its env's
        current ``goal``), and the goal space is discretized.

        Args:
            env_type: Must be ``'train'`` or ``'test'``. In this method it is
                only validated and included in the ``locals()`` handed to
                ``Serializable.quick_init``.
            sample_all: Passed through unchanged to the parent constructor.

        Raises:
            ValueError: If ``env_type`` is neither ``'train'`` nor ``'test'``.
        """
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently accept any env_type.
        if env_type not in ('train', 'test'):
            raise ValueError(
                "env_type must be 'train' or 'test', got {!r}".format(env_type))
        # quick_init receives ``locals()``; do not bind any new local names
        # above this line, or they would be included in what it receives.
        Serializable.quick_init(self, locals())

        cls_dict = {}
        args_kwargs = {}
        for group_key, group_envs in HARD_MODE_CLS_DICT.items():
            for task, env_cls in group_envs.items():
                cls_dict[task] = env_cls
                args_kwargs[task] = HARD_MODE_ARGS_KWARGS[group_key][task]
        # Sanity check: the flattened registry must hold exactly 50 tasks.
        assert len(cls_dict) == 50

        super().__init__(
            task_env_cls_dict=cls_dict,
            task_args_kwargs=args_kwargs,
            sample_goals=False,
            obs_type='with_goal_id',
            sample_all=sample_all,
        )

        # One fixed goal per task: a copy of that env's current goal.
        goals_dict = {
            task: [env.goal.copy()]
            for task, env in zip(self._task_names, self._task_envs)
        }

        self.discretize_goal_space(goals_dict)
        assert self._fully_discretized
Example #3
0
def generate_mt50_env(mt_param):
    """Build the MT50 multi-task env from metaworld's hard-mode registry.

    ``HARD_MODE_CLS_DICT`` / ``HARD_MODE_ARGS_KWARGS`` (shaped
    ``{group: {task: value}}``) are flattened into task-keyed dicts. If a task
    name appears in more than one group, the later group wins.

    Args:
        mt_param: Dict of keyword arguments forwarded as ``**mt_param`` to
            ``generate_mt_env``. If it contains ``"random_init"``, that value
            also overrides ``["kwargs"]["random_init"]`` in every task's spec.

    Returns:
        Tuple ``(env, cls_dict, args_kwargs)``:
            - ``env``: whatever ``generate_mt_env`` returns.
            - ``cls_dict``: task name -> entry from ``HARD_MODE_CLS_DICT``.
            - ``args_kwargs``: task name -> args/kwargs spec actually used.
    """
    from metaworld.envs.mujoco.env_dict import HARD_MODE_CLS_DICT, HARD_MODE_ARGS_KWARGS

    cls_dict = {}
    args_kwargs = {}
    for group_key, group_envs in HARD_MODE_CLS_DICT.items():
        for task, env_cls in group_envs.items():
            cls_dict[task] = env_cls
            # NOTE: this is the registry's own dict object, not a copy.
            args_kwargs[task] = HARD_MODE_ARGS_KWARGS[group_key][task]

    if "random_init" in mt_param:
        random_init = mt_param["random_init"]
        # Copy each spec and its nested "kwargs" before overriding. Writing
        # into the aliased dicts would mutate metaworld's module-level
        # HARD_MODE_ARGS_KWARGS and leak the override into later callers.
        args_kwargs = {
            task: {**spec, "kwargs": {**spec["kwargs"], "random_init": random_init}}
            for task, spec in args_kwargs.items()
        }

    env = generate_mt_env(cls_dict, args_kwargs, **mt_param)
    return env, cls_dict, args_kwargs