Exemplo n.º 1
0
def get_overcooked_obj_attr(attr,
                            env=None,
                            mdp=None,
                            env_params=None,
                            mdp_params=None):
    """
    Resolve an Overcooked object attribute from its description.

    Used mostly to get state processing (encoding) functions and gym spaces.

    Supported forms of ``attr``:

    * ``str`` -- parsed as ``"<obj>.<attribute name>"`` where ``<obj>`` is
      ``"env"`` or ``"mdp"``, e.g. ``"env.lossless_state_encoding_mdp"``;
      the named attribute is looked up on the matching object.
    * ``dict`` -- a new dict is returned with the same keys and every value
      resolved recursively with the same ``env``/``mdp``/params. Each value
      is resolved by its own call, so if no ``env``/``mdp`` is supplied every
      string value constructs its own objects.
    * anything else (e.g. a method/function) -- returned unchanged.

    Args:
        attr: attribute description (see above).
        env: environment object to read ``"env.*"`` attributes from; when
            ``mdp`` is not given, ``env.mdp`` is used for ``"mdp.*"`` lookups.
        mdp: mdp object to read ``"mdp.*"`` attributes from; also used to
            build the environment when ``env`` is not given.
        env_params: keyword arguments for ``OvercookedEnv.from_mdp``; passed
            through ``only_valid_named_args`` first (presumably dropping
            names ``from_mdp`` does not accept -- confirm against helper).
            Only used when ``env`` is None.
        mdp_params: keyword arguments for ``OvercookedGridworld``; only used
            when an mdp has to be constructed.

    Returns:
        The resolved attribute, a dict of resolved attributes, or ``attr``
        itself when it is neither a string nor a dict.

    Raises:
        ValueError: if a string is not of the form ``"env.<name>"`` or
            ``"mdp.<name>"`` (wrong prefix, missing dot, or several dots).
    """
    if isinstance(attr, str):
        name = attr
        parts = name.split(".")
        # Exactly one dot is required; report malformed strings with the same
        # error as an unknown object name instead of an unpacking failure.
        if len(parts) != 2:
            raise ValueError("Unsupported obj attr string " + name)
        obj_name, attr_name = parts
        if obj_name == "mdp":
            # Compare against None explicitly: a supplied object may be falsy
            # (e.g. define __len__/__bool__) and must still be used as given.
            if mdp is None:
                if env is not None:
                    mdp = env.mdp
                else:
                    mdp = OvercookedGridworld(**mdp_params)
            attr = getattr(mdp, attr_name)
        elif obj_name == "env":
            if env is None:
                if mdp is None:
                    mdp = OvercookedGridworld(**mdp_params)
                env_params = only_valid_named_args(env_params,
                                                   OvercookedEnv.from_mdp)
                env = OvercookedEnv.from_mdp(mdp, **env_params)
            attr = getattr(env, attr_name)
        # not tested or used anywhere yet
        # elif obj_name in kwargs:
        #     attr = getattr(kwargs[obj_name], attr_name)
        else:
            raise ValueError("Unsupported obj attr string " + name)
    elif isinstance(attr, dict):
        attr = {
            k: get_overcooked_obj_attr(v,
                                       env=env,
                                       mdp=mdp,
                                       env_params=env_params,
                                       mdp_params=mdp_params)
            for k, v in attr.items()
        }
    # not tested or used anywhere yet
    # elif attr_type in [list, tuple]:
    #     attr = attr_type(get_overcooked_obj_attr(elem, env=env, mdp=mdp, env_params=env_params,
    #         mdp_params=mdp_params) for elem in attr)
    return attr
Exemplo n.º 2
0
 def from_trajectories_json(trajs, idx=0):
     """
     Build an OvercookedEnv from the parameters stored at position ``idx``
     of a trajectories dict.

     Reads ``trajs["mdp_params"][idx]`` to construct the OvercookedGridworld
     and ``trajs["env_params"][idx]`` (passed through ``only_valid_named_args``
     for ``OvercookedEnv.from_mdp``) for the environment keyword arguments.
     """
     # Construct the mdp first, then look up and filter the env kwargs.
     gridworld = OvercookedGridworld(**trajs["mdp_params"][idx])
     stored_env_kwargs = trajs["env_params"][idx]
     env_kwargs = only_valid_named_args(stored_env_kwargs,
                                        OvercookedEnv.from_mdp)
     return OvercookedEnv.from_mdp(gridworld, **env_kwargs)