def __init__(self, verbose, context_id, change_cb):
    """
    @brief      Prepare the ID generator and the change callback, then run
                the IndividualsDataset constructor with init=False

    @param      verbose     Forwarded unchanged to IndividualsDataset.__init__
    @param      context_id  Forwarded unchanged to IndividualsDataset.__init__
    @param      change_cb   Object stored as self._change_cb
    """
    # Own state is assigned before the parent constructor runs.
    self._change_cb = change_cb
    self._id_gen = IdGen()
    IndividualsDataset.__init__(self, verbose, context_id, init=False)
def __init__(self, verbose, context_id, change_cb):
    """
    @brief      Store a fresh IdGen and the change callback on the instance,
                then delegate to IndividualsDataset.__init__ with init=False

    @param      verbose     Passed through to IndividualsDataset.__init__
    @param      context_id  Passed through to IndividualsDataset.__init__
    @param      change_cb   Kept as self._change_cb
    """
    # The parent constructor is deliberately called last.
    self._change_cb = change_cb
    self._id_gen = IdGen()
    IndividualsDataset.__init__(self, verbose, context_id, init=False)
class WorldModel(IndividualsDataset):
    """
    @brief      A set of individuals with unique ID generation

    Relation statements added to or removed from the dataset, and property
    updates, are reported through the change callback given at construction.
    """

    def __init__(self, verbose, context_id, change_cb):
        """
        @brief      Create the ID generator, store the change callback and
                    run the IndividualsDataset constructor with init=False

        @param      verbose     Forwarded to IndividualsDataset.__init__
        @param      context_id  Forwarded to IndividualsDataset.__init__
        @param      change_cb   Callable invoked as
                                change_cb(author, action, ...) on changes
        """
        self._id_gen = IdGen()
        self._change_cb = change_cb
        IndividualsDataset.__init__(self, verbose, context_id, init=False)

    def reset(self, add_root=True, scene_name="skiros:blank_scene"):
        """
        @brief      Initialize the scene

        @param      add_root    (bool) When True a root element is added
                                after the reset
        @param      scene_name  (string) Name of the root individual; it is
                                reused if it already exists, otherwise a new
                                "skiros:Scene" element is created
        """
        IndividualsDataset.reset(self)
        self._id_gen.clear()
        if add_root:
            if self.has_individual(scene_name):
                root = self.get_individual(scene_name)
            else:
                root = Element("skiros:Scene", scene_name)
            self.add_element(root, self.__class__.__name__)

    def _remove(self, statement, author, is_relation=False):
        """
        @brief      Remove a statement from the scene

        @param      statement    Triple indexed as (subject, predicate, object)
        @param      author       Forwarded to the parent and to the callback
        @param      is_relation  (bool) When True the removal is published
                                 through the change callback
        """
        IndividualsDataset._remove(self, statement, author, is_relation)
        if is_relation:
            self._change_cb(author, "remove", relation={
                'src': self.uri2lightstring(statement[0]),
                'type': self.uri2lightstring(statement[1]),
                'dst': self.uri2lightstring(statement[2])
            })

    def _add(self, statement, author, is_relation=False):
        """
        @brief      Add a statement to the scene

        @param      statement    Triple indexed as (subject, predicate, object)
        @param      author       Forwarded to the parent and to the callback
        @param      is_relation  (bool) When True the addition is published
                                 through the change callback
        """
        IndividualsDataset._add(self, statement, author, is_relation)
        if is_relation:
            self._change_cb(author, "add", relation={
                'src': self.uri2lightstring(statement[0]),
                'type': self.uri2lightstring(statement[1]),
                'dst': self.uri2lightstring(statement[2])
            })

    def _uri2type(self, uri):
        """
        @brief      Return the part of the uri that precedes the first '-'
        """
        return uri.split('-')[0]

    def _uri2id(self, uri):
        """
        @brief      Extract the numeric ID that follows the first '-'

        @param      uri   (string) Uri such as "<type>-<number>"

        @return     (int) The ID, or -1 when the uri has no '-' or the text
                    after it is not an integer
        """
        if uri.find('-') < 0:
            return -1
        try:
            return int(uri.split('-')[1])
        except ValueError:
            # A dash inside a plain name (e.g. "my-robot") is not an ID
            # separator; report "no ID" instead of crashing the caller.
            return -1

    @synchronized
    def load_context(self, filename):
        """
        @brief      Load scene from file

        @param      filename  (string) When non-empty it replaces the stored
                              file name before loading
        """
        if filename:
            self._filename = filename
        if not path.isfile(self.filedir):
            log.error("[load_context]",
                      "Can't load scene {}. File not found. ".format(self.filename))
            return
        self._stop_reasoners()
        self.reset(add_root=False)
        self.context.parse(self.filedir, format='turtle')
        individuals = self.context.query(
            "SELECT ?x WHERE { ?x rdf:type <http://www.w3.org/2002/07/owl#NamedIndividual>. } ")
        # Reserve the IDs already used by the loaded individuals
        for i in individuals:
            i = self.uri2lightstring(i[0])
            iid = self._uri2id(i)
            if iid >= 0:
                self._id_gen.getId(iid)
        self._start_reasoners()
        log.info("[load_context]", "Loaded scene {}. ".format(self.filename))

    @synchronized
    def add_element(self, e, author):
        """
        @brief      Add an element to the scene

        @param      e       Element; its uri is set from a generated ID
        @param      author  Forwarded to IndividualsDataset.add_element

        @return     The same element, with its uri set
        """
        e.setUri(self._id_gen.getId(e.getIdNumber()))
        IndividualsDataset.add_element(self, e, author)
        return e

    @synchronized
    def update_element(self, e, author):
        """
        @brief      Update an element in the scene

        Logs an error and does nothing when the element ID is not known to
        the ID generator.
        """
        if not self._id_gen.hasId(self._uri2id(e.id)):
            log.error("[update_element]",
                      "Request update from {}, but Id {} is not present in the wm. ".format(
                          author, self._uri2id(e.id)))
            return
        IndividualsDataset.update_element(self, e, author)

    @synchronized
    def update_properties(self, e, author, reasoner=None, publish=True):
        """
        @brief      Update properties of an element in the scene

        @param      e         Element to update
        @param      author    Forwarded to the parent and to the callback
        @param      reasoner  Forwarded to
                              IndividualsDataset.update_properties
        @param      publish   (bool) When True the updated element is sent
                              through the change callback
        """
        if not self._id_gen.hasId(self._uri2id(e.id)):
            log.error("[update_properties]",
                      "Request update from {}, but Id {} is not present in the wm. ".format(
                          author, self._uri2id(e.id)))
            return
        IndividualsDataset.update_properties(self, e, author, reasoner)
        if publish:
            self._change_cb(author, "update", self.get_element(e.id))

    @synchronized
    def remove_element(self, e, author):
        """
        @brief      Remove an element from the scene

        @return     (bool) True when the parent removed the element (its ID
                    is then released), False otherwise
        """
        if IndividualsDataset.remove_element(self, e, author):
            self._id_gen.removeId(self._uri2id(e.id))
            return True
        return False
class SkillCore(SkillDescription):
    # ID generator shared by every skill instance
    gen_id = IdGen()

    def __init__(self):
        """
        @brief      An abstract executable skill with a description (type,
                    label, params, conditions), a state and progress code
        """
        # Description
        self._id = SkillCore.gen_id.getId()
        self._type = self._label = ""
        self._description = SkillDescription()
        # Params
        self._params = params.ParamHandler()
        # Conditions
        self._pre_conditions, self._hold_conditions, self._post_conditions = [], [], []
        # Execution
        self._state_change = Event()
        self._state = State.Uninitialized
        self._avg_time_keeper = TimeKeeper()
        self._time_keeper = TimeKeeper()
        self._progress_code = 0
        self._progress_period = self._progress_time = 0.0
        self._progress_msg = ""
        self._expand_on_start = False

    # --------Class functions--------

    def expand(self, skill):
        """
        @brief      No-op in the base class
        """
        return

    def hasChildren(self):
        """
        @brief      The base class never has children

        @return     (bool) Always False
        """
        return False

    def _setState(self, state):
        """
        @brief      Store the new state and signal the state-change event
        """
        self._state = state
        self._state_change.set()

    def _setProgress(self, msg, code=None):
        """
        @brief      Record a progress message, code and timings

        @param      msg   Message; stored as str(msg)
        @param      code  (int) Progress code. When None the previous code
                          is incremented by one
        """
        self._progress_code = self._progress_code + 1 if code is None else code
        self._progress_period = self._avg_time_keeper.get_avg_time()
        self._progress_time = self._time_keeper.time_from_start()
        self._progress_msg = str(msg)

    @property
    def id(self):
        return self._id

    @property
    def progress_code(self):
        return self._progress_code

    @property
    def progress_period(self):
        return self._progress_period

    @property
    def progress_time(self):
        return self._progress_time

    @property
    def progress_msg(self):
        return self._progress_msg

    @property
    def state(self):
        return self._state

    @property
    def expand_on_start(self):
        """
        @brief      Default False. If true, the skill will expand every time
                    it is started. Used e.g. in a planner skill
        """
        return self._expand_on_start

    def _resetDescription(self, other=None):
        """
        @brief      Re-initialise params and conditions from the description

        @param      other  Optional object whose _params are merged into
                           the description params; when falsy the
                           description params are deep-copied instead
        """
        description = self._description
        if not other:
            self._params = deepcopy(description._params)
        else:
            self._params.reset(description._params.merge(other._params))
        self._pre_conditions = deepcopy(self._description._pre_conditions)
        self._hold_conditions = deepcopy(self._description._hold_conditions)
        self._post_conditions = deepcopy(self._description._post_conditions)

    def hasPreCond(self):
        """
        @return     (bool) True when at least one pre-condition is defined
        """
        return bool(self._pre_conditions)

    def checkPreCond(self, verbose=False):
        """
        @brief      Check pre-conditions. The collected failure text is
                    stored as progress message with code -1.

        @param      verbose  (bool) Print error message when check fail

        @return     A list of parameters that breaks the conditions, or an
                    empty list if all are satisfied
        """
        broken_keys = []
        report = ""
        for cond in self._pre_conditions:
            if cond.evaluate(self._params, self._wmi):
                continue
            report += "{} Check failed. \n".format(cond.getDescription())
            if verbose:
                log.error(cond.getDescription(), "ConditionCheck failed")
            broken_keys.extend(cond.getKeys())
        self._setProgress(report, -1)
        return list(set(broken_keys))

    def checkHoldCond(self, verbose=False):
        """
        @brief      Check hold-conditions. The collected failure text is
                    stored as progress message with code -2.

        @param      verbose  (bool) Print error message when check fail

        @return     A list of parameters that breaks the conditions, or an
                    empty list if all are satisfied
        """
        broken_keys = []
        report = ""
        for cond in self._hold_conditions:
            if cond.evaluate(self._params, self._wmi):
                continue
            report += "{} Check failed. \n".format(cond.getDescription())
            if verbose:
                log.error("HoldConditionCheck failed", cond.getDescription())
            broken_keys.extend(cond.getKeys())
        self._setProgress(report, -2)
        return list(set(broken_keys))

    def hasPostCond(self):
        """
        @return     (bool) True when at least one post-condition is defined
        """
        return bool(self._post_conditions)

    def checkPostCond(self, verbose=False):
        """
        @brief      Check post-conditions.

        @param      verbose  (bool) Print error message when check fail

        @return     A list of parameters that breaks the conditions, or an
                    empty list if all are satisfied
        """
        broken_keys = []
        for cond in self._post_conditions:
            if cond.evaluate(self._params, self._wmi):
                continue
            if verbose:
                log.error(cond.getDescription(), "ConditionCheck failed")
            broken_keys.extend(cond.getKeys())
        return list(set(broken_keys))

    # -------- Control functions--------

    def preempt(self):
        """
        @brief      Stop a running skill: the state becomes the value
                    returned by onPreempt(), or Failure when onEnd() is falsy

        @return     (State) The state after the call
        """
        if not self.hasState(State.Running):
            return self._state
        self._setState(self.onPreempt())
        if not self.onEnd():
            self._setState(State.Failure)
        return self._state

    def getState(self):
        return self._state

    def hasState(self, state):
        return self._state == state

    def waitState(self, state, isset=True):
        """
        @brief      Block until the state equals (isset truthy) or differs
                    from (isset falsy) the given state

        NOTE(review): the event is cleared right before waiting, so a state
        change landing between the check and the clear looks like it could
        be missed - confirm with the callers' threading model.
        """
        event = self._state_change
        if isset:
            # Wait until the requested state is reached
            while self._state != state:
                event.clear()
                event.wait()
        else:
            # Wait until the given state is left
            while self._state == state:
                event.clear()
                event.wait()

    def reset(self):
        """
        @brief      Call onReset(), restore default params, reset both time
                    keepers, clear the progress and go Idle

        @return     (State) The state after the call
        """
        self.onReset()
        self._params.setDefault()
        for keeper in (self._time_keeper, self._avg_time_keeper):
            keeper.reset()
        self._setProgress("", 0)
        self._setState(State.Idle)
        return self._state

    def start(self, params=None):
        """
        @brief      Start the skill

        @param      params  Optional parameters, applied with
                            specifyParams(params, False) before starting

        @return     (State) Running when onStart() is truthy, Failure
                    otherwise
        """
        if params:
            self.specifyParams(params, False)
        self._time_keeper.reset()
        if not self.onStart():
            self._setState(State.Failure)
            return self._state
        self._setState(State.Running)
        self._setProgress("Start", 0)
        return self._state

    def printInfo(self, verbose=False):
        """
        @return     (string) "<type>-<label> ", followed by the params state
                    and the conditions when verbose is set
        """
        header = "{}-{} ".format(self._type, self._label)
        if not verbose:
            return header + "\n"
        header += "["
        header += self._params.printState() + ']\n'
        return header + self.printConditions()

    def printState(self, verbose=False):
        """
        @return     (string) "<type without prefix>-<label>(<state>)",
                    followed by the params state when verbose is set
        """
        short_type = self.type.split(":", 1)[-1]
        text = "{}-{}({})".format(short_type, self.label, self.state)
        if verbose:
            text += "[{}]".format(self.params.printState())
        return text

    def printProgress(self):
        """
        @return     (string) "[<progress code>] <progress message>"
        """
        return "[{}] {}".format(self._progress_code, self._progress_msg)

    def specifyParamDefault(self, key, values):
        """
        @brief      Specify a value and set it as default value too

        @param      key     (string) Parameter key
        @param      values  Parameter value(s)
        """
        known = self._params.hasParam(key)
        if not known:
            log.error("specifyParamDefault",
                      "No param '{}' found. Debug: {}".format(key, self.printInfo(True)))
        self._params.specifyDefault(key, values)

    def specifyParam(self, key, values):
        """
        @brief      Specify a parameter and update the input cache

        @param      key     (string) Parameter key
        @param      values  Parameter value(s)
        """
        known = self._params.hasParam(key)
        if not known:
            log.error("specifyParam",
                      "No param '{}' found. Debug: {}".format(key, self.printInfo(True)))
        self._params.specify(key, values)

    def specifyParamsDefault(self, input_params):
        """
        @brief      Set the parameters and makes them default (they will no
                    more be overwritten by specifyParams, even with
                    keep_default=False)

        @param      input_params  (dict)
        """
        self._params.specifyParamsDefault(input_params)

    def specifyParams(self, input_params, keep_default=True):
        """
        @brief      Set the parameters

        @param      input_params  (dict) Parameters to set
        @param      keep_default  (bool) If True, params already specified
                                  are preserved
        """
        self._params.specifyParams(input_params, keep_default)

    # -------- User's functions--------

    def setDescription(self, description, label=""):
        """
        @brief      Description is a SkillDescription

        @param      description  The description; its _type becomes the
                                 skill type
        @param      label        (string) New label; ignored when empty
        """
        self._description = description
        self._type = description._type
        if label != "":
            self._label = label
        self._resetDescription()

    def startError(self, msg, code):
        """
        @brief      signal an error during the starting routine

        @return     (bool) Always False
        """
        assert type(msg) is str
        assert type(code) is int
        self.fail(msg, code)
        return False

    def step(self, msg=""):
        """
        @brief      Set a running breakpoint

        @return     (State) State.Running
        """
        assert type(msg) is str
        self._setProgress(msg)
        return State.Running

    def fail(self, msg, code):
        """
        @brief      Set a failure state. Positive codes are negated before
                    being stored.

        @return     (State) State.Failure
        """
        assert type(msg) is str
        assert type(code) is int
        self._setProgress(msg, -abs(code))
        return State.Failure

    def success(self, msg=""):
        """
        @brief      Set a success state

        @return     (State) State.Success
        """
        assert type(msg) is str
        self._setProgress(msg)
        return State.Success

    # -------- Virtual functions--------

    def modifyDescription(self, skill):
        """
        @brief      Override to define additional parameters/condition over
                    the skill
        """
        pass

    def onReset(self):
        """
        @brief      Called when resetting.
        """
        pass

    def onStart(self):
        """
        @brief      Called just before 1st execute

        @return     (Bool)
        """
        return True

    def onPreempt(self):
        """
        @brief      Called when skill is requested to stop.

        @return     (State)
        """
        self._setProgress("Preempted", -1)
        return State.Failure

    def execute(self):
        """
        @brief      Main execution function

        @return     (State)
        """
        raise NotImplementedError("Not implemented in abstract class")

    def onEnd(self):
        """
        @brief      Called just after last execute or preemption

        @return     (Bool)
        """
        return True
class BtTicker:
    """
    Manager of a set of Behavior Trees (Tasks) and a visitor

    Ticks the tasks sequentially, with the specified visitor

    Provides interfaces to start, pause, stop the ticking process and to
    add/remove tasks

    All state is kept in class attributes, so it is shared by every instance.
    """
    _verbose = True
    _tasks_to_preempt = []
    # uid -> number of ticks still allowed while paused (0 = fully paused)
    _tasks_to_pause = {}
    _tasks = {}
    _process = None
    _visitor = None
    _id_gen = IdGen()
    _progress_cb = None
    _tick_cb = None
    _finished_skill_ids = {}

    def _run(self, _):
        """
        @brief      Tick tasks at 25hz

        Publishes an initial snapshot of every task, then ticks until no
        task is left, calling the tick callback after every cycle.
        """
        BtTicker._finished_skill_ids = {}
        rate = rospy.Rate(25)
        log.info("[BtTicker]", "Execution starts.")
        for uid in list(BtTicker._tasks):
            task = BtTicker._tasks[uid]
            printer = visitors.VisitorPrint(BtTicker._visitor._wm,
                                            BtTicker._visitor._instanciator)
            printer.traverse(task)
            self.publish_progress(uid, printer)
        while BtTicker._tasks:
            self._tick()
            rate.sleep()
            self._tick_cb()
        log.info("[BtTicker]", "Execution stops.")

    def _tick(self):
        """
        @brief      Traverse every task once with the current visitor

        Handles pending preemptions, skips fully paused tasks, consumes one
        allowed tick of a "tick once" task, and removes tasks whose result
        is neither Running nor Idle.
        """
        visitor = BtTicker._visitor
        for uid in list(BtTicker._tasks):
            if uid in BtTicker._tasks_to_preempt:
                BtTicker._tasks_to_preempt.remove(uid)
                visitor.preempt()
            if uid in BtTicker._tasks_to_pause:
                if BtTicker._tasks_to_pause[uid] <= 0:
                    continue
                BtTicker._tasks_to_pause[uid] -= 1
            result = visitor.traverse(BtTicker._tasks[uid])
            self.publish_progress(uid, visitor)
            if result not in (State.Running, State.Idle):
                self.remove_task(uid)

    def kill(self):
        """
        @brief      Drop the reference to the ticking process, then tick once
                    from the calling context
        """
        if BtTicker._process is not None:
            del BtTicker._process
            BtTicker._process = None
            self._tick()

    def is_running(self):
        """
        @return     (bool) False when no process exists, otherwise the
                    result of its is_alive()
        """
        return BtTicker._process is not None and BtTicker._process.is_alive()

    def publish_progress(self, uid, visitor):
        """
        @brief      Send every entry of the visitor snapshot to the progress
                    callback

        @param      uid      Task identifier, passed as task_id
        @param      visitor  Object whose snapshot() yields (id, desc) pairs
        """
        # TODO: check timings when removing the filtering. The disabled
        # filter skipped entries whose 'state' and 'msg' were unchanged with
        # respect to BtTicker._finished_skill_ids.
        for skill_id, desc in visitor.snapshot():
            self._progress_cb(task_id=uid, id=skill_id, **desc)

    def observe_progress(self, func):
        """
        @brief      Register the progress callback
        """
        self._progress_cb = func

    def observe_tick(self, func):
        """
        @brief      Register the callback invoked after every tick cycle
        """
        self._tick_cb = func

    def clear(self):
        """
        @brief      Preempt the visitor (if any), wait for the process to
                    finish, then forget all tasks and IDs
        """
        if BtTicker._visitor:
            BtTicker._visitor.preempt()
            BtTicker._process.join()
            BtTicker._visitor = None
        BtTicker._tasks.clear()
        BtTicker._id_gen.clear()

    def add_task(self, obj, desired_id=-1):
        """
        @brief      Register a task and label it "task_<uid>"

        @param      obj         The task
        @param      desired_id  Forwarded to the ID generator

        @return     The uid assigned to the task
        """
        uid = BtTicker._id_gen.getId(desired_id)
        obj._label = "task_{}".format(uid)
        BtTicker._tasks[uid] = obj
        return uid

    def remove_task(self, uid):
        """
        @brief      Unregister a task and release its uid
        """
        BtTicker._tasks.pop(uid)
        BtTicker._id_gen.removeId(uid)

    def start(self, visitor, uid):
        """
        @brief      Resume a paused task or start ticking

        @param      visitor  Visitor used when a new ticking process is
                             created
        @param      uid      Task identifier

        @return     True when a new ticking process was started, None
                    otherwise
        """
        resuming = uid in BtTicker._tasks_to_pause
        if resuming:
            log.info("[start]", "Resuming task {}.".format(uid))
            del BtTicker._tasks_to_pause[uid]
        else:
            log.info("[start]", "Starting task {}.".format(uid))
        if not self.is_running():
            BtTicker._visitor = visitor
            BtTicker._process = Process(target=BtTicker._run, args=(self, True))
            BtTicker._process.start()
            return True

    def join(self):
        """
        @brief      Wait for the ticking process to finish
        """
        BtTicker._process.join()

    def pause(self, uid):
        """
        @brief      Pause a task (no tick allowed)
        """
        log.info("[pause]", "Pausing task {}.".format(uid))
        BtTicker._tasks_to_pause[uid] = 0

    def tick_once(self, uid):
        """
        @brief      Allow exactly one more tick for a task, then keep it
                    paused
        """
        log.info("[tick_once]", "Tick once task {}.".format(uid))
        BtTicker._tasks_to_pause[uid] = 1

    def preempt(self, uid):
        """
        @brief      Request preemption of a task and wait up to 5 seconds
                    for the ticking process to stop, killing it otherwise
        """
        log.info("[preempt]", "Stopping task {}...".format(uid))
        if uid in BtTicker._tasks_to_pause:
            del BtTicker._tasks_to_pause[uid]
        BtTicker._tasks_to_preempt.append(uid)
        started_at = rospy.Time.now()
        timeout = rospy.Duration(5.0)
        while self.is_running() and rospy.Time.now() - started_at < timeout:
            rospy.sleep(0.01)
        if self.is_running():
            log.info("preempt", "Task {} is not answering. Killing process.".format(uid))
            self.kill()
        log.info("preempt", "Task {} preempted.".format(uid))

    def preempt_all(self):
        """
        @brief      Preempt every registered task
        """
        for uid in list(BtTicker._tasks):
            self.preempt(uid)

    def pause_all(self):
        """
        @brief      Pause every registered task
        """
        for uid in list(BtTicker._tasks):
            self.pause(uid)

    def tick_once_all(self):
        """
        @brief      Allow one more tick for every registered task
        """
        for uid in list(BtTicker._tasks):
            self.tick_once(uid)
def __init__(self, wmi, scene_name=None):
    """
    @brief      Store the world model interface, create the ID generator,
                disable sync and verbosity, then reset

    @param      wmi         Object stored as self._wmi
    @param      scene_name  Forwarded to self.reset(); defaults to None
    """
    self._verbose = False
    self._keep_sync = False
    self._wmi = wmi
    self._id_gen = IdGen()
    # reset() runs last, once every attribute above exists
    self.reset(scene_name)
class WorldModel:
    """
    This world model implementation is made to remain local and interface
    when necessary with the global wm

    Elements are stored as nodes of a graph (property dictionaries), and
    relations as edges carrying a "type". An index from element type to
    element ids is kept in self._types.
    """
    _id = 0
    _graph = sn.Graph()
    _types = {}

    def __init__(self, wmi, scene_name=None):
        """
        @param      wmi         Interface to the main world model
        @param      scene_name  Forwarded to reset()
        """
        self._id_gen = IdGen()
        self._wmi = wmi
        self._keep_sync = False
        self._verbose = False
        self.reset(scene_name)

    def __copy__(self):
        """
        Return a new WorldModel that shares the graph and the type index
        with this one (they are not duplicated).
        """
        wm = WorldModel(self._wmi)
        wm._verbose = self._verbose
        wm._id = self._id
        wm._graph = self._graph
        wm._types = self._types
        return wm

    def __deepcopy__(self, memo):
        """
        Same as __copy__; the result is registered in memo.
        """
        result = self.__copy__()
        memo[id(self)] = result
        return result

    def __enter__(self):
        """
        Enable synchronization with the main wm for the "with" body
        """
        self.syncKeep(True)
        return self

    def __exit__(self, type, value, traceback):
        """
        Disable synchronization with the main wm
        """
        self.syncKeep(False)

    def reset(self, scene_name=None):
        """
        Replace graph and type index with empty ones. When scene_name is
        given, a ":Scene" root element with id number 0 is added to the graph.
        """
        self._id = 0
        self._graph = sn.Graph()
        self._types = dict()
        if scene_name:
            root = Element(":Scene", scene_name, 0)
            props = {"type": root._type, "label": root._label}
            self._graph.add_node(dict(chain(props.items(), root._properties.items())), root._id)

    def _addType(self, etype, eid):
        """
        Register element id under its type in the type index
        """
        if etype not in self._types:
            self._types[etype] = []
        self._types[etype].append(eid)

    def _removeType(self, etype, eid):
        """
        Remove element id from the type index; logs an error when the type
        or the id is not registered
        """
        try:
            self._types[etype].remove(eid)
        except (KeyError, ValueError):
            # KeyError: unknown type, ValueError: id not listed for the type
            log.error("_removeType", "No element id: {} type: {}".format(eid, etype))

    def _getTypes(self, etype):
        """
        Fast retrieval of elements of same type

        Returns the graph nodes registered under etype and under every
        sub-class reported by the main wm.
        """
        to_ret = []
        if etype in self._types:
            to_ret += [self._graph.get_node(t) for t in self._types[etype] if self._graph.has_node(t)]
        for c in self._wmi.get_sub_classes(etype, True):
            if c in self._types:
                to_ret += [self._graph.get_node(t) for t in self._types[c] if self._graph.has_node(t)]
        return to_ret

    def syncKeep(self, value=True):
        """
        When set, the modifications are syncronized with main wm
        """
        self._keep_sync = value

    def pushElement(self, e, action):
        """
        Update element to main wm

        action is one of "add", "update", "remove"; nothing happens unless
        synchronization is enabled.
        """
        if self._keep_sync:
            if action == "add":
                self._wmi.add_element(e)
            elif action == "update":
                e._relations = self.getContextRelations(e)
                self._wmi.update_element(e)
            elif action == "remove":
                self._wmi.remove_element(e)

    def pushRelation(self, sub, rel, obj, value):
        """
        Update relation to main wm (only when synchronization is enabled)
        """
        if self._keep_sync:
            self._wmi.set_relation(sub, rel, obj, value)

    def sync(self):
        """
        Pull the graph from the main world model
        """
        self.importGraph(self._wmi.get_branch("skiros:Scene-0"))

    def importGraph(self, elements):
        """
        Reset the model and rebuild it from a list of elements: first all
        nodes, then the relations whose source is the element itself ("-1").
        """
        self.reset()
        for e in elements:
            self._addNode(e)
        for e in elements:
            for r in e._relations:
                try:
                    if r['src'] == "-1":
                        self._addEdge(e._id, r['type'], r['dst'])
                    # NOTE: i have to skip passive relations...this could create problems
                except Exception:
                    log.error("[importGraph]", "Skipping relation {}. The child node was not imported.".format(r))
                    continue

    def importRelations(self, relations):
        """
        Set every relation of the list; each item is unpacked into
        set_relation()
        """
        for r in relations:
            self.set_relation(*r)

    def _printRecursive(self, root, indend, relation_filter):
        """
        Print root and, recursively, its children with a growing indent
        """
        s = root.printState()
        print(indend + s)
        indend = "-" * (len(indend) + len(s)) + "->"
        for e in self.getChildren(root._id, relation_filter):
            self._printRecursive(e, indend, relation_filter)

    def printModel(self, relation_filter="skiros:sceneProperty"):
        """
        Print the tree rooted at "skiros:Scene-0", following relation_filter
        """
        root = self.get_element("skiros:Scene-0")
        self._printRecursive(root, "", relation_filter)
        return

    def getAbstractElement(self, etype, elabel):
        """
        Create an element of the given type and label, add it with a
        'hasAbstract' relation from parent id 0, and return it
        """
        e = Element(etype, elabel)
        self.add_element(e, 0, 'hasAbstract')
        return e

    def resolve_elements2(self, keys, ph):
        """
        Return all elements matching the profile in input (type, label, properties and relations)

        Keys: a key list pointing out the params to be resolved
        ph: a ParamHandler class

        Returns a dictionary mapping a key, or a tuple of coupled keys, to
        the array of candidate elements (or element tuples).

        Raises RuntimeError when the merge of overlapping key tuples does
        not converge.
        """
        first = {}
        couples = {}
        print_out = False
        for key in keys:
            first[key] = np.array(self.resolve_element(ph.getParamValue(key)))
            if not first[key].any():
                log.warn("resolve_elements", "No input found for param {}. Resolving: {}".format(key, ph.getParamValue(key).printState(True)))
        all_keys = [key for key, _ in ph._params.items()]
        coupled_keys = []
        overlap_keys = []
        relations_done = set()
        # Build tuples of concording parameters
        for key_base in all_keys:  # Loop over all keys
            if not isinstance(ph.getParamValue(key_base), Element):
                continue
            for j in ph.getParamValue(key_base)._relations:  # Loop over relation constraints
                if j["src"] == "-1":  # -1 is the special autoreferencial value
                    key, key2 = key_base, j["dst"]
                else:
                    key, key2 = j["src"], key_base
                rel_id = key + j["type"] + key2
                if rel_id in relations_done:  # Skip relation with previous indexes, already considered
                    continue
                relations_done.add(rel_id)
                if not ph.hasParam(key) or not ph.hasParam(key2):  # Check necessary because at the moment ._relations contains a mix Toclean
                    continue
                this = ph.getParamValue(key)
                other = ph.getParamValue(key2)
                if this.getIdNumber() >= 0 and other.getIdNumber() >= 0:  # If both parameters are already set, no need to resolve..
                    continue
                if this.getIdNumber() >= 0:
                    set1 = [this]
                elif ph.getParam(key).paramType() == params.ParamTypes.Optional:
                    continue
                else:
                    set1 = first[key]
                if other.getIdNumber() >= 0:
                    set2 = [other]
                elif ph.getParam(key2).paramType() == params.ParamTypes.Optional:
                    continue
                else:
                    set2 = first[key2]
                if (key, key2) not in couples:
                    # First constraint on this pair: track which keys are shared
                    if key in coupled_keys:
                        overlap_keys.append(key)
                    else:
                        coupled_keys.append(key)
                    if key2 in coupled_keys:
                        overlap_keys.append(key2)
                    else:
                        coupled_keys.append(key2)
                temp = [np.array([e1, e2]) for e1 in set1 for e2 in set2 if bool(self.get_relations(e1._id, j["type"], e2._id)) == j['state']]
                if not temp:
                    log.warn("resolve_elements", "No input for params {} {}. Resolving: {} {}".format(key, key2, ph.getParamValue(key).printState(True), ph.getParamValue(key2).printState(True)))
                if (key, key2) not in couples:
                    couples[(key, key2)] = np.array(temp)
                elif temp:
                    existing = couples[(key, key2)]
                    new_pairs = np.array(temp)
                    # np.concatenate takes ONE sequence of arrays; an empty
                    # existing array is 1-D and cannot be joined with pairs
                    if existing.size:
                        couples[(key, key2)] = np.concatenate((existing, new_pairs))
                    else:
                        couples[(key, key2)] = new_pairs
        # Merge the tuples with an overlapping key
        if overlap_keys:
            loop = True
            iters = 5
            while loop:  # Iterate until no shared keys are found
                iters -= 1
                if iters == 0:
                    raise RuntimeError("resolve_elements2: merging of overlapping keys did not converge")
                loop = False
                coupled_keys2 = []
                merged = {}
                for k1, s1 in couples.items():
                    for k2, s2 in couples.items():
                        shared_k = [k for k in k1 if k in k2]
                        if k1 == k2 or not shared_k:
                            continue
                        loop = True
                        skip = True
                        for i in k1:
                            if i not in coupled_keys2:
                                coupled_keys2.append(i)
                                skip = False
                        for i in k2:
                            if i not in coupled_keys2:
                                coupled_keys2.append(i)
                                skip = False
                        if skip:
                            continue  # If it was already considered, skip
                        rk, rs = self._intersect(k1, k2, s1, s2, shared_k)
                        merged[rk] = rs  # Temporary store merged tuple
                for key in keys:  # Add not merged tuples
                    if key not in coupled_keys2:
                        for k1, s1 in couples.items():
                            if key in k1:
                                merged[k1] = s1
                couples = merged
        # Add back keys that are not coupled to others
        for key in keys:
            if key not in coupled_keys:
                couples[key] = first[key]
        if print_out:
            for k, v in couples.items():
                s = "{}:".format(k)
                for i in v:
                    if not isinstance(i, Element):
                        s += "["
                        for j in i:
                            s += "{},".format(j)
                        s += "]"
                    else:
                        s += "{},".format(i)
                print(s)
        return couples

    def _concatenate(self, a, b):
        """
        Concatenate a and b, wrapping each into a one-item array when it is
        not already a numpy array
        """
        if not isinstance(a, np.ndarray):
            a = np.array([a])
        if not isinstance(b, np.ndarray):
            b = np.array([b])
        return np.concatenate((a, b))

    def _intersect(self, k1, k2, s1, s2, shared_k):
        """
        Join two sets of element tuples on their shared keys

        k1, k2: key tuples describing the columns of s1 and s2
        s1, s2: arrays of element tuples
        shared_k: keys present in both k1 and k2

        Returns (keys, sets): keys is k1 followed by the keys of k2 that
        are not shared; sets holds the combinations of rows from s1 and s2
        that agree on every shared key.
        """
        a = [k1.index(k) for k in shared_k]
        b = [k2.index(k) for k in shared_k]
        c = np.arange(len(k1))
        # Columns of k2 that are not shared
        d = np.delete(np.arange(len(k2)), b)
        keys = list(k1) + [k for k in k2 if k not in shared_k]
        sets = []
        for v1 in s1:
            for v2 in s2:
                if not any(v1[ia] != v2[ib] for ia, ib in zip(a, b)):
                    sets.append(np.array(self._concatenate(v1[c], v2[d])))
        return tuple(keys), np.array(sets)

    def resolve_element(self, description):
        """
        Return all elements matching the profile in input (type, label, properties)

        An empty or "Unknown" label matches any label. A property value
        that is "" or None ends the check of that property.
        """
        first = []
        to_ret = []
        # Get all nodes matching type and label
        for e in self._getTypes(description._type):
            if description._label == "" or description._label == "Unknown" or e['label'] == description._label:
                first.append(self._makeElement(e))
        # Filter by properties
        for e in first:
            add = True
            for k, p in description._properties.items():
                if not e.hasProperty(k):
                    add = False
                    break
                for v in p.getValues():
                    if v == "" or v is None:
                        break
                    if e.getProperty(k).find(v) < 0:
                        add = False
                        break
                if not add:
                    break
            if add:
                to_ret.append(e)
        return to_ret

    def _makeElement(self, props):
        """
        Build an Element from a node dictionary: "id", "type" and "label"
        are taken out, everything else becomes the element properties
        """
        e = Element()
        remaining = deepcopy(props)
        e._id = remaining.pop("id")
        e._type = remaining.pop("type")
        e._label = remaining.pop("label")
        e._properties = remaining
        return e

    def get_element(self, eid):
        """
        Return the element with the given id

        Raises KeyError (with debug information) when the id is unknown.
        """
        try:
            eprops = self._graph.get_node(eid)
        except KeyError:
            raise KeyError("{} not found. Debug: {} {}".format(eid, self._graph, self._types))
        return self._makeElement(eprops)

    def _addNode(self, element):
        """
        Assign an id to the element, store it in the graph and index its type
        """
        element.setUri(self._id_gen.getId(element.getIdNumber()))
        if self._verbose:
            log.debug('add', str(element._id))
        props = {"type": element._type, "label": element._label}
        self._graph.add_node(dict(chain(props.items(), element._properties.items())), element._id)
        self._addType(element._type, element._id)

    def _resolve_local_relations(self, e, lr):
        """
        Add or update the elements referenced by the local relations lr of
        e, linking each of them back to e
        """
        for r in lr:
            sub_e = r['dst']
            sub_e.addRelation(e._id, r['type'], "-1")
            if sub_e._id == "":
                if self.add_element2(sub_e) < 0:
                    log.error("[{}]".format(self.__class__.__name__), "Failed to add local element {}".format(sub_e))
            else:
                if self.update_element(sub_e) < 0:
                    log.error("[{}]".format(self.__class__.__name__), "Failed to update local element {}".format(sub_e))

    def add_element2(self, element):
        """
        Add an element together with its relations and local relations

        Returns the id assigned to the element.
        """
        lr = copy(element._local_relations)
        element._local_relations = list()
        self.pushElement(element, "add")
        self._addNode(element)
        for r in element._relations:
            if r['src'] == "-1":
                self.set_relation(element._id, r['type'], r['dst'], True, push=False)
            else:
                self.set_relation(r['src'], r['type'], element._id, True, push=False)
        self._resolve_local_relations(element, lr)
        return element._id

    def add_element(self, element, parent_id, relation):
        """
        Add an element and relate it to a parent

        Returns the id assigned to the element.
        """
        self.pushElement(element, "add")
        self._addNode(element)
        self.set_relation(parent_id, relation, element._id, True, push=False)
        return element._id

    def update_element(self, element):
        """
        Update relations and properties of an existing element

        Returns the element id, or None (after a warning) when the element
        is not in the graph.
        """
        for r in element._relations:
            if r['src'] == "-1":
                self.set_relation(element._id, r['type'], r['dst'], True, push=False)
            else:
                self.set_relation(r['src'], r['type'], element._id, True, push=False)
        self.pushElement(element, "update")
        if not self._graph.has_node(element._id):
            log.warn("update_element", "No element found with key {}".format(element._id))
            return
        props = {"type": element._type, "label": element._label}
        self._graph.add_node(dict(chain(props.items(), element._properties.items())), element._id)
        return element._id

    def remove_element(self, eid):
        """
        Remove the element with the given id from graph and type index
        """
        if self._verbose:
            log.debug('remove', str(eid))
        self.pushElement(eid, "remove")
        eprops = self._graph.get_node(eid)
        self._removeType(eprops["type"], eid)
        self._graph.remove_node(eid)

    def _check_relation(self, esubject, relation, eobject, value, push):
        """
        Remove the old contain relation, to maintain the tree structure
        """
        if self.isRelationType(relation, "skiros:sceneProperty") and value:
            self.set_relation("-1", "skiros:sceneProperty", eobject, False, push)

    def isElementType(self, etype, abstract_type):
        """
        True when etype equals abstract_type or is one of its sub-classes
        """
        return etype == abstract_type or (self._wmi.addPrefix(etype) in self._wmi.get_sub_classes(abstract_type, True))

    def isRelationType(self, relation, rtype="skiros:sceneProperty"):
        """
        True when relation equals rtype or is one of its sub-properties
        """
        return relation == rtype or self._wmi.addPrefix(relation) in self._wmi.get_sub_properties(rtype, True)

    def _addEdge(self, esubject, relation, eobject):
        """
        Add an edge of the given relation type between subject and object
        """
        if self._verbose:
            log.debug('add', str(esubject) + "-" + relation + "-" + str(eobject))
        self._graph.add_edge(esubject, eobject, {"type": relation})

    def set_relation(self, esubject, relation, eobject, value=True, push=True):
        """
        Add (value True) or remove (value False) a relation

        push: when True the change is also sent to the main wm
        Always returns True. On error the model is printed and the
        exception is re-raised.
        """
        if self.get_relations(esubject, relation, eobject) and value:  # Avoid adding twice the same statement
            return True
        self._check_relation(esubject, relation, eobject, value, push)
        try:
            if value:
                self._addEdge(esubject, relation, eobject)
            else:
                for e in self.get_relations(esubject, relation, eobject, True):
                    self._graph.remove_edge(e)
            if push:
                self.pushRelation(esubject, relation, eobject, value)
        except BaseException:
            self.printModel()
            raise
        return True

    def getAssociatedReasoner(self, relation):
        """
        Return an instance of the first DiscreteReasoner subclass that
        lists the relation among its associated relations, or None
        """
        for cls in DiscreteReasoner.__subclasses__():
            instance = cls()
            if relation in instance.getAssociatedRelations():
                return instance
        return None

    def get_relations(self, esubject, relation, eobject, getId=False):
        """
        Return the relations matching subject, relation type and object

        An empty string acts as a wildcard for each of the three filters.
        getId: when True the edge ids are returned instead of edge copies.
        When all three filters are set (and getId is False) the relation
        computed by the subject reasoner, if any, is appended too.
        """
        rel = []
        for _, edge in self._graph.get_edges().items():
            if (esubject == "" or edge['src'] == esubject) and (eobject == "" or edge['dst'] == eobject) and (relation == "" or self.isRelationType(edge['type'], relation)):
                if getId:
                    rel.append(edge['id'])
                else:
                    rel.append(deepcopy(edge))
        if not getId and relation != "" and esubject != "" and eobject != "":
            try:
                s = self.get_element(esubject)
                o = self.get_element(eobject)
                reasoner = s._getReasoner(relation)
                if relation in reasoner.computeRelations(s, o):
                    rel.append({"src": esubject, "type": relation, "dst": eobject})
            except KeyError:
                pass
        return rel

    def getContextRelations(self, esubject):
        """
        Get all relations related to a subject

        The subject id is replaced by "-1" in the returned copies and the
        edge "id" is dropped.
        """
        rel = []
        for _, edge in self._graph.get_edges().items():
            if edge['src'] == esubject._id:
                new_edge = deepcopy(edge)
                del new_edge['id']
                new_edge['src'] = "-1"
                rel.append(new_edge)
            elif edge['dst'] == esubject._id:
                new_edge = deepcopy(edge)
                del new_edge['id']
                new_edge['dst'] = "-1"
                rel.append(new_edge)
        return rel

    def getChildren(self, eid, relation="skiros:sceneProperty"):
        """
        Return the elements that are destination of the given relation
        starting from eid
        """
        return [self.get_element(edge['dst']) for edge in self.get_relations(eid, relation, "")]

    def getParent(self, eid):
        """
        Return the source element of the first "skiros:sceneProperty"
        relation pointing to eid, or None when there is none
        """
        # With a wildcard subject get_relations() never queries a reasoner,
        # so no extra flag is needed (the old "getReasonersRel" keyword is
        # not accepted by get_relations and raised TypeError).
        for edge in self.get_relations("", "skiros:sceneProperty", eid):
            return self.get_element(edge['src'])