def get_skillset(self):
    """Build and return a table describing every skill and action.

    Entries come from ``self.skill_set`` followed by ``self.actions``
    (both are read through ``.items()``); an action whose key equals a
    skill key therefore replaces that skill's entry.  Each entry is a
    ``DictTree`` whose ``step`` is always ``None`` and whose remaining
    fields are copied from the skill class, with defaults for optional
    attributes.

    For every entry that lists sub-skills, two derived fields are added:

    * ``ret_in_len``  -- largest ``ret_out_len`` among its sub-skills.
    * ``arg_out_len`` -- the larger of its own ``ret_out_len`` and the
      largest ``arg_in_len`` among its sub-skills.

    Side effect: ``self.stack`` and ``self.last_act_name`` are reset to
    ``None`` before returning.
    """
    named_classes = list(self.skill_set.items())
    named_classes.extend(self.actions.items())

    entries = {}
    for name, cls in named_classes:
        entries[name] = DictTree(
            step=None,
            model_name=getattr(cls, 'model_name', "log_poly2"),
            arg_in_len=cls.arg_in_len,
            max_cnt=getattr(cls, 'max_cnt', None),
            sub_skill_names=getattr(cls, 'sub_skill_names', []),
            ret_out_len=cls.ret_out_len,
            min_valid_data=getattr(cls, 'min_valid_data', None),
            sub_arg_accuracy=getattr(cls, 'sub_arg_accuracy', None),
        )
    skillset = DictTree(entries)

    for info in skillset.values():
        sub_names = info.sub_skill_names
        if not sub_names:
            # Entries without sub-skills get no derived lengths.
            continue
        info.ret_in_len = max(skillset[sub].ret_out_len for sub in sub_names)
        widest_sub_arg = max(skillset[sub].arg_in_len for sub in sub_names)
        info.arg_out_len = max(info.ret_out_len, widest_sub_arg)

    self.stack = None
    self.last_act_name = None
    return skillset
class HierarchicalAgent(agent.Agent):
    """Agent driven by a stack of skills that can invoke sub-skills.

    Subclasses must override the ``root_skill_name``, ``skills``,
    ``actions`` and ``default_model_name`` properties; the base versions
    raise ``NotImplementedError``.

    ``self.skillset`` maps each skill/action class name to a ``DictTree``
    describing it.  ``self.stack`` holds the active call frames (``None``
    until ``reset`` is called) and ``self.last_act_name`` remembers the
    name most recently returned by ``step``.
    """

    def __init__(self, config):
        """Build the skill table from ``self.skills + self.actions``.

        ``config.rollable`` decides whether each class's ``step``
        attribute is kept (otherwise ``step`` is ``None``).  When
        ``config.rollable`` is set and ``config.teacher`` is not, the
        ``step`` of every entry with sub-skills is replaced by the result
        of ``load_skill(config.model_dirname, name, entry)``.
        """
        super(HierarchicalAgent, self).__init__(config)

        entries = {}
        for skill_cls in self.skills + self.actions:
            if config.rollable:
                step_fn = getattr(skill_cls, 'step', None)
            else:
                step_fn = None
            entries[skill_cls.__name__] = DictTree(
                step=step_fn,
                model_name=getattr(skill_cls, 'model_name', self.default_model_name),
                arg_in_len=skill_cls.arg_in_len,
                max_cnt=getattr(skill_cls, 'max_cnt', None),
                sub_skill_names=getattr(skill_cls, 'sub_skill_names', []),
                ret_out_len=skill_cls.ret_out_len,
                min_valid_data=getattr(skill_cls, 'min_valid_data', None),
                sub_arg_accuracy=getattr(skill_cls, 'sub_arg_accuracy', None),
            )
        self.skillset = DictTree(entries)

        # Derived lengths exist only on entries that have sub-skills.
        for info in self.skillset.values():
            sub_names = info.sub_skill_names
            if not sub_names:
                continue
            info.ret_in_len = max(
                self.skillset[sub].ret_out_len for sub in sub_names)
            widest_sub_arg = max(
                self.skillset[sub].arg_in_len for sub in sub_names)
            info.arg_out_len = max(info.ret_out_len, widest_sub_arg)

        use_loaded_steps = config.rollable and not config.teacher
        if use_loaded_steps:
            for skill_name, info in self.skillset.items():
                if info.sub_skill_names:
                    info.step = load_skill(config.model_dirname, skill_name, info)

        self.stack = None
        self.last_act_name = None

    @property
    def root_skill_name(self):
        """Name of the skill that ``reset`` puts on the stack first."""
        raise NotImplementedError

    @property
    def skills(self):
        """Skill classes; concatenated with ``actions`` using ``+``."""
        raise NotImplementedError

    @property
    def actions(self):
        """Action classes; concatenated after ``skills`` using ``+``."""
        raise NotImplementedError

    @property
    def default_model_name(self):
        """Model name used for classes that define no ``model_name``."""
        raise NotImplementedError

    def reset(self, init_arg):
        """Start over with a single root frame carrying ``init_arg``."""
        root_frame = DictTree(name=self.root_skill_name, arg=init_arg, cnt=0)
        self.stack = [root_frame]
        self.last_act_name = None

    def step(self, obs):
        """Run skills from the top of the stack until one picks an action.

        The top frame's ``step`` is called with
        ``(arg, cnt, ret_name, ret_val, obs)`` and returns
        ``(sub_name, sub_arg)``.  On the first call ``ret_name`` is the
        previously returned action name and ``ret_val`` is ``obs``.

        * ``sub_name is None``: the frame is popped and its name and
          ``sub_arg`` become the return name/value seen by the caller
          frame.
        * ``sub_name`` (with every ``'Record_'`` removed) names an entry
          that has sub-skills: a new frame is pushed.
        * otherwise: ``sub_name`` is returned to the caller of ``step``.

        Every call made is printed and recorded.

        Returns:
            ``(sub_name, sub_arg, DictTree(steps=...))``, or
            ``(None, None, DictTree(steps=...))`` once the stack is empty.
        """
        ret_name, ret_val = self.last_act_name, obs
        trace = []
        while self.stack:
            frame = self.stack[-1]
            skill_step = self.skillset[frame.name].step
            sub_name, sub_arg = skill_step(
                frame.arg, frame.cnt, ret_name, ret_val, obs)

            # Record the call before the frame counter is advanced.
            record = DictTree(
                name=frame.name,
                arg=frame.arg,
                cnt=frame.cnt,
                ret_name=ret_name,
                ret_val=ret_val,
                sub_name=sub_name,
                sub_arg=sub_arg,
            )
            trace.append(record)
            print("{}({}, {}, {}, {}) -> {}({})".format(
                frame.name, frame.arg, frame.cnt, ret_name, ret_val,
                sub_name, sub_arg))

            if sub_name is None:
                # The skill finished: hand its result to the frame below.
                self.stack.pop()
                ret_name, ret_val = frame.name, sub_arg
                continue

            lookup_name = sub_name.replace('Record_', '')
            callee_sub_skills = self.skillset[lookup_name].sub_skill_names
            frame.cnt += 1
            if not callee_sub_skills:
                # No sub-skills: treat as an action and return it.
                self.last_act_name = sub_name
                return sub_name, sub_arg, DictTree(steps=trace)

            self.stack.append(DictTree(name=sub_name, arg=sub_arg, cnt=0))
            ret_name, ret_val = None, None

        self.last_act_name = None
        return None, None, DictTree(steps=trace)