def union(
    self, other_during: "DuringAction[_ObjectT]"
) -> "DuringAction[_ObjectT]":
    """
    Unify two DuringAction together.

    For unifying spatial paths, the paths from `self` override any conflicts in
    `other_during`: the paths of both actions are pooled per object (those of `self`
    first), and every later path of an object sharing a `reference_source_object`
    with an earlier one is unified into the earlier one with ``override=True``.
    `at_some_point` and `continuously` relations of both actions are concatenated.
    """
    objects_to_paths = immutablesetmultidict(
        chain(self.objects_to_paths.items(), other_during.objects_to_paths.items())
    )

    objects_to_unified_paths: List[Tuple[_ObjectT, SpatialPath[_ObjectT]]] = []
    for obj in objects_to_paths.keys():
        paths = objects_to_paths[obj]
        # Skips are tracked per object: the same path value may legitimately be
        # attached to several objects and must be kept for each of them.
        paths_to_skip: Set[SpatialPath[_ObjectT]] = set()
        for (num, path) in enumerate(paths):
            if path in paths_to_skip:
                continue
            unified_path = path
            # Start after `num`; unifying a path with itself is pointless.
            for i in range(num + 1, len(paths)):
                later_path = paths[i]
                if path.reference_source_object == later_path.reference_source_object:
                    # NOTE(review): assumes `SpatialPath.unify` returns the merged path
                    # (immutable value) rather than mutating in place -- confirm.
                    unified_path = unified_path.unify(later_path, override=True)
                    paths_to_skip.add(later_path)
            objects_to_unified_paths.append((obj, unified_path))

    return DuringAction(
        objects_to_paths=immutablesetmultidict(objects_to_unified_paths),
        at_some_point=chain(self.at_some_point, other_during.at_some_point),
        continuously=chain(self.continuously, other_during.continuously),
    )
def _init_objects_to_actions(
    self,
) -> ImmutableSetMultiDict[ObjectSemanticNode, AttributeSemanticNode]:
    """
    Index every action in `self.actions` under each object filling one of its slots.

    Pairs are produced action by action, in slot-filling order.
    """
    # NOTE(review): values come from `self.actions` but the annotation says
    # AttributeSemanticNode -- confirm whether ActionSemanticNode was intended.
    return immutablesetmultidict(
        (filler, action)
        for action in self.actions
        for filler in action.slot_fillings.values()
    )
class _OverLappingHasSpanIndex(HasSpanIndex[T]):
    """
    A ``HasSpanIndex`` implementation tolerating items whose spans overlap.

    Exact-span lookups go straight to the underlying multidict; every other query is
    answered by scanning all indexed spans.
    """

    _span_to_item_index: ImmutableSetMultiDict[Span, T] = attrib(
        converter=immutablesetmultidict, default=immutablesetmultidict()  # type: ignore
    )

    def _items_whose_span(self, keep: "Callable[[Span], bool]") -> ImmutableSet[T]:
        """Collect the items indexed under every span for which *keep* is true."""
        index = self._span_to_item_index
        return immutableset(
            item
            for candidate_span in index
            if keep(candidate_span)
            for item in index[candidate_span]
        )

    def get_exactly_matching(self, span: Span) -> ImmutableSet[T]:
        return self._span_to_item_index[span]

    def get_overlapping(self, span: Span) -> ImmutableSet[T]:
        return self._items_whose_span(lambda candidate: candidate.overlaps(span))

    def get_contained(self, span: Span) -> ImmutableSet[T]:
        return self._items_whose_span(lambda candidate: span.contains_span(candidate))

    def get_containing(self, span: Span) -> ImmutableSet[T]:
        return self._items_whose_span(lambda candidate: candidate.contains_span(span))
def get_largest_matching_pattern(
    pattern: PerceptionGraphPattern,
    graph: PerceptionGraph,
    *,
    debug_callback: Optional[DebugCallableType] = None,
    graph_logger: Optional[GraphLogger] = None,
    ontology: Ontology,
    match_ratio: Optional[float] = None,
    match_mode: MatchMode,
    trim_after_match: Optional[
        Callable[[PerceptionGraphPattern], PerceptionGraphPattern]
    ] = None,
    allowed_matches: ImmutableSetMultiDict[Any, Any] = immutablesetmultidict(),
) -> Optional[PerceptionGraphPattern]:
    """
    Return the largest part of *pattern* which matches *graph*, if any.

    A matcher for *pattern* against *graph* is built and then asked to relax the
    pattern until it matches; its result is returned unchanged. *match_ratio* is
    forwarded as ``min_ratio``.
    """
    return pattern.matcher(
        graph,
        debug_callback=debug_callback,
        match_mode=match_mode,
        allowed_matches=allowed_matches,
    ).relax_pattern_until_it_matches(
        graph_logger=graph_logger,
        ontology=ontology,
        min_ratio=match_ratio,
        trim_after_match=trim_after_match,
    )
def _to_immutablesetmultidict(
    val: Optional[
        Union[
            Iterable[Tuple[Any, Any]],
            Mapping[Any, Any],
            ImmutableSetMultiDict[Any, Any],
        ]
    ]
) -> ImmutableSetMultiDict[Any, Any]:
    """
    Coerce *val* to an `ImmutableSetMultiDict` via `immutablesetmultidict`.

    This named wrapper serves as an attrs ``converter``. It is needed until
    https://github.com/python/mypy/issues/5738 and
    https://github.com/python-attrs/attrs/issues/519 are fixed.
    """
    converted = immutablesetmultidict(val)
    return converted
class Action(Generic[_ActionTypeT, _ObjectT]):
    r"""
    An action.

    Bound to `SituationObject`\ s it represents actions in `Situation`\ s; bound to
    `TemplateObjectVariable`\ s it represents actions in situation templates.
    """

    action_type: _ActionTypeT = attrib()
    argument_roles_to_fillers: ImmutableSetMultiDict[
        OntologyNode, Union[_ObjectT, Region[_ObjectT]]
    ] = attrib(converter=_to_immutablesetmultidict, default=immutablesetmultidict())
    r"""
    Maps semantic roles (`OntologyNode`\ s) to whatever fills them.

    A single role may have several fillers (e.g. conjoined arguments).
    """
    # the optional below seems to confuse mypy?
    during: Optional[DuringAction[_ObjectT]] = attrib(  # type: ignore
        validator=optional(instance_of(DuringAction)), default=None, kw_only=True
    )
    auxiliary_variable_bindings: ImmutableDict[
        ActionDescriptionVariable, _ObjectT
    ] = attrib(converter=_to_immutabledict, default=immutabledict(), kw_only=True)
    """
    Maps variables of *action_type*'s `ActionDescription` to the items bound to them.
    """

    def accumulate_referenced_objects(
        self, object_accumulator: List[_ObjectT]
    ) -> None:
        r"""
        Append every object this action refers to onto *object_accumulator*.

        Order: argument fillers, then objects referenced by `during` (if set), then
        auxiliary variable bindings. `Region`\ s contribute via their own
        ``accumulate_referenced_objects`` instead of being appended themselves.
        """

        def _collect(value) -> None:
            if isinstance(value, Region):
                value.accumulate_referenced_objects(object_accumulator)
            else:
                object_accumulator.append(value)

        for _, role_filler in self.argument_roles_to_fillers.items():
            _collect(role_filler)
        if self.during:
            self.during.accumulate_referenced_objects(object_accumulator)
        for bound_item in self.auxiliary_variable_bindings.values():
            _collect(bound_item)

    def __repr__(self) -> str:
        pieces = [str(self.argument_roles_to_fillers)]
        if self.during:
            pieces.append(f"during={self.during}")
        if self.auxiliary_variable_bindings:
            pieces.append(
                f"auxiliary_variable_bindings={self.auxiliary_variable_bindings}"
            )
        return f"{self.action_type}({', '.join(pieces)})"
def index(items: Iterable[T]) -> "HasSpanIndex[T]":
    """
    Build a ``HasSpanIndex`` over *items*, keyed by each item's ``span``.

    The returned index tolerates overlapping spans.
    """
    span_item_pairs = ((item.span, item) for item in items)
    # mypy is confused
    return _OverLappingHasSpanIndex(immutablesetmultidict(span_item_pairs))  # type: ignore
class ActionDescription:
    frame: ActionDescriptionFrame = attrib(
        validator=instance_of(ActionDescriptionFrame), kw_only=True
    )
    # nested generic in optional seems to be confusing mypy
    during: Optional[DuringAction[ActionDescriptionVariable]] = attrib(  # type: ignore
        validator=optional(instance_of(DuringAction)), default=None, kw_only=True
    )
    # conditions which hold both before and after the action
    enduring_conditions: ImmutableSet[Relation[ActionDescriptionVariable]] = attrib(
        converter=flatten_relations, default=immutableset(), kw_only=True
    )
    # conditions which hold before the action
    preconditions: ImmutableSet[Relation[ActionDescriptionVariable]] = attrib(
        converter=flatten_relations, default=immutableset(), kw_only=True
    )
    # conditions which hold after the action
    postconditions: ImmutableSet[Relation[ActionDescriptionVariable]] = attrib(
        converter=flatten_relations, default=immutableset(), kw_only=True
    )
    # properties asserted of the objects taking part in the action
    asserted_properties: ImmutableSetMultiDict[
        ActionDescriptionVariable, OntologyNode
    ] = attrib(
        converter=_to_immutablesetmultidict, default=immutablesetmultidict(), kw_only=True
    )
    auxiliary_variables: ImmutableSet[ActionDescriptionVariable] = attrib(init=False)
    """
    Variables which occupy no semantic role but are still referred to by conditions,
    paths, etc. An example would be the container for liquid for a "drink" action.
    """

    def __attrs_post_init__(self) -> None:
        # After conversion every condition must be a Relation; fail on the first one that is not.
        for condition in chain(
            self.enduring_conditions, self.preconditions, self.postconditions
        ):
            if isinstance(condition, Relation):
                continue
            raise RuntimeError(
                f"All conditions on an action description ought to be Relations "
                f"but got {condition}"
            )

    @auxiliary_variables.default
    def _init_auxiliary_variables(self) -> ImmutableSet[ActionDescriptionVariable]:
        # Gather everything referenced by `during` and by all conditions, then keep
        # only the variables the frame does not assign a role to.
        referenced: List[ActionDescriptionVariable] = []
        if self.during:
            self.during.accumulate_referenced_objects(referenced)
        for condition in chain(
            self.enduring_conditions, self.preconditions, self.postconditions
        ):
            condition.accumulate_referenced_objects(referenced)
        return immutableset(
            candidate
            for candidate in referenced
            if candidate not in self.frame.variables_to_roles
        )
def tokens_to_mentions(
    tokens: TokenTheory, mentions: MentionTheory
) -> ImmutableSetMultiDict[Token, Mention]:
    """Map each token to every mention which the span index reports as containing it."""
    mention_index = HasSpanIndex.index(mentions)
    # NOTE(review): the token itself is passed where `get_containing` is annotated to
    # take a Span -- confirm Token is usable as a Span here.
    return immutablesetmultidict(
        [
            (token, mention)
            for token in tokens
            for mention in mention_index.get_containing(token)
        ]
    )
def copy_remapping_objects(
    self, object_map: Mapping[_ObjectT, _ObjectToT]
) -> "AxesInfo[_ObjectToT]":
    """
    Copy this `AxesInfo`, replacing every referenced object by its image in *object_map*.

    Both the addressee (when one is set) and every key of `axes_facing` are looked
    up in *object_map*, so each must be present there.
    """
    return AxesInfo(
        # `is not None` rather than truthiness: a present-but-falsy addressee must
        # still be remapped instead of silently becoming None.
        addressee=object_map[self.addressee] if self.addressee is not None else None,
        axes_facing=immutablesetmultidict(
            (object_map[key], value) for (key, value) in self.axes_facing.items()
        ),
    )
def _init_concepts_to_templates(
    self,
) -> ImmutableSetMultiDict[Concept, SurfaceTemplate]:
    """Map each recognizable object concept (plus the ground) to its object-name template."""
    concepts_and_names = list(
        self._object_recognizer._concepts_to_names.items()  # pylint:disable=protected-access
    )
    # Ground is added explicitly to this list because the code
    # which matches the ground matches by recognition and not shape.
    # See: `ObjectRecognizer.match_objects`
    concepts_and_names.append((GROUND_OBJECT_CONCEPT, "ground"))
    return immutablesetmultidict(
        (
            concept,
            SurfaceTemplate.for_object_name(name, language_mode=self._language_mode),
        )
        for concept, name in concepts_and_names
    )
class Digraph:
    """
    A minimal directed graph over string node names.

    Requirements:
    - ``edges`` is in successor form: each key node maps to the nodes it points to,
      i.e. every edge is stored as ``(node_from, node_to)``. ``predecessors`` is the
      inverse mapping.
    - Every node taking part in an edge must appear in the master ``nodes`` list.
    """

    nodes: ImmutableSet[str] = attrib(converter=_to_immutableset, default=immutableset())
    edges: ImmutableSetMultiDict[str, str] = attrib(
        converter=_to_immutablesetmultidict,
        default=immutablesetmultidict(),
        validator=validate_edges,
    )
    predecessors: ImmutableSetMultiDict[str, str] = attrib(init=False)

    def in_degree(self):
        """Return an `InDegreeView` of this graph (iterates as ``(node, degree)`` pairs)."""
        return InDegreeView(self)

    def topological_sort(self) -> Iterator[str]:
        """
        Lazily yield the nodes so that every node comes after all of its predecessors.

        Algorithm adapted from NetworkX
        https://github.com/networkx/networkx/blob/39a1c6f5471cd3adf476a3bd5355dcaa2e8a6160/networkx/algorithms/dag.py#L121

        Raises:
            ParameterInterpolationError: after all reachable nodes have been yielded,
                if some nodes still have unprocessed incoming edges (a cycle).
        """
        waiting = {}  # node -> count of incoming edges not yet processed
        ready = []  # nodes with no unprocessed incoming edges
        for vertex, degree in self.in_degree():
            if degree > 0:
                waiting[vertex] = degree
            else:
                ready.append(vertex)

        while ready:
            vertex = ready.pop()
            for successor in self.edges[vertex]:
                waiting[successor] -= 1
                if not waiting[successor]:
                    del waiting[successor]
                    ready.append(successor)
            yield vertex

        # Because this method is only for supporting parameter interpolation, provide a
        # user-friendly error message here to avoid needing access to the leftover
        # in-degree map externally.
        if waiting:
            raise ParameterInterpolationError(
                "These interpolated parameters form at least one graph cycle "
                f"that must be fixed: {tuple(waiting.keys())}"
            )

    @predecessors.default
    def init_predecessors(self) -> ImmutableSetMultiDict[str, str]:
        """Default for ``predecessors``: the inverse of ``edges``."""
        return self.edges.invert_to_set_multidict()
class AxesInfo(Generic[_ObjectT], CanRemapObjects[_ObjectT]):
    """
    An optional addressee plus a multi-mapping from objects to `GeonAxis` values.

    NOTE(review): the exact meaning of "facing" in `axes_facing` is not visible here.
    """

    addressee: Optional[_ObjectT] = attrib(default=None)
    axes_facing: ImmutableSetMultiDict[_ObjectT, GeonAxis] = attrib(
        converter=_to_immutablesetmultidict, default=immutablesetmultidict()
    )

    def copy_remapping_objects(
        self, object_map: Mapping[_ObjectT, _ObjectToT]
    ) -> "AxesInfo[_ObjectToT]":
        """
        Copy this `AxesInfo`, replacing every referenced object by its image in *object_map*.

        Both the addressee (when one is set) and every key of `axes_facing` are looked
        up in *object_map*, so each must be present there.
        """
        return AxesInfo(
            # `is not None` rather than truthiness: a present-but-falsy addressee must
            # still be remapped instead of silently becoming None.
            addressee=object_map[self.addressee] if self.addressee is not None else None,
            axes_facing=immutablesetmultidict(
                (object_map[key], value) for (key, value) in self.axes_facing.items()
            ),
        )
def _find_partial_match(
    self,
    hypothesis: PerceptionGraphTemplate,
    graph: PerceptionGraph,
    *,
    required_alignments: Mapping[SyntaxSemanticsVariable, ObjectSemanticNode],
) -> "AbstractPursuitLearnerNew.PartialMatch":
    """
    Match as much of *hypothesis*'s pattern against *graph* as possible.

    Pattern nodes of the template variables in *required_alignments* are restricted
    to match their aligned object nodes. Increments ``self.debug_counter``. The
    result records the (possibly absent) partially-matched hypothesis, the number of
    nodes matched, and the size of the original pattern.
    """
    pattern = hypothesis.graph_pattern
    slot_node_alignments = immutablesetmultidict(
        [
            (hypothesis.template_variable_to_pattern_node[variable], object_node)
            for variable, object_node in required_alignments.items()
        ]
    )
    common_subgraph = get_largest_matching_pattern(
        pattern,
        graph,
        debug_callback=self._debug_callback,
        graph_logger=self._hypothesis_logger,
        ontology=self._ontology,
        match_mode=MatchMode.NON_OBJECT,
        allowed_matches=slot_node_alignments,
    )
    self.debug_counter += 1

    pattern_node_count = len(pattern)
    partial_hypothesis: Optional[PerceptionGraphTemplate]
    if common_subgraph:
        matched_node_count = len(common_subgraph.copy_as_digraph().nodes)
        partial_hypothesis = PerceptionGraphTemplate(
            graph_pattern=common_subgraph,
            template_variable_to_pattern_node=hypothesis.template_variable_to_pattern_node,
        )
    else:
        matched_node_count = 0
        partial_hypothesis = None

    # NOTE(review): declared to return AbstractPursuitLearnerNew.PartialMatch but
    # builds an attribute-learner-specific match -- confirm this is intended.
    return PursuitAttributeLearnerNew.AttributeHypothesisPartialMatch(
        partial_hypothesis,
        num_nodes_matched=matched_node_count,
        num_nodes_in_pattern=pattern_node_count,
    )
def _update_hypothesis(
    self,
    previous_pattern_hypothesis: PerceptionGraphTemplate,
    current_pattern_hypothesis: PerceptionGraphTemplate,
) -> Optional[PerceptionGraphTemplate]:
    """
    Intersect the previous hypothesis with the current one.

    Pattern nodes filling the same template slot in both hypotheses are allowed to
    match each other, pairing (current node, previous node).
    """
    current_slots_to_nodes = current_pattern_hypothesis.template_variable_to_pattern_node
    # A mapping lookup pairs up shared slots directly instead of comparing every
    # previous slot with every current slot. Pairs keep the previous hypothesis' order.
    slot_aligned_nodes = immutablesetmultidict(
        [
            (current_slots_to_nodes[slot], previous_node)
            for slot, previous_node in (
                previous_pattern_hypothesis.template_variable_to_pattern_node.items()
            )
            if slot in current_slots_to_nodes
        ]
    )
    return previous_pattern_hypothesis.intersection(
        current_pattern_hypothesis,
        ontology=self._ontology,
        match_mode=MatchMode.NON_OBJECT,
        allowed_matches=slot_aligned_nodes,
    )
def _orientation_to_size(self, piece: "Piece") -> ImmutableSetMultiDict[int, int]:
    """
    Pair a zero-count with each rotation index of *piece*.

    ``piece.shape`` is iterated as one entry per rotation, each entry being a sequence
    of rows. For every rotation, the number of ``"0"`` cells in the last row (in
    iteration order) containing any ``"0"`` is recorded as ``(count, rotation_index)``.

    Raises:
        RuntimeError: if some rotation has no row containing ``"0"``.
    """
    rotation_to_size: List[Tuple[int, int]] = []
    for rotation_index, rows in enumerate(piece.shape):
        last_filled_row = None
        for row in rows:
            if "0" in row:
                last_filled_row = row
        if last_filled_row is None:
            raise RuntimeError("Failed to find last line of shape.")
        # Count within the row found above, not within whichever row the scan ended on.
        rotation_to_size.append((last_filled_row.count("0"), rotation_index))
    return immutablesetmultidict(rotation_to_size)
def test_initialization(self) -> None:
    # `predecessors` is derived automatically as the inverse of `edges`.
    expected_predecessors = immutablesetmultidict(
        (
            ("2", "11"),
            ("8", "3"),
            ("8", "7"),
            ("9", "8"),
            ("9", "11"),
            ("10", "3"),
            ("10", "11"),
            ("11", "5"),
            ("11", "7"),
        )
    )
    self.assertEqual(self.GRAPH.predecessors, expected_predecessors)

    # An edge mentioning a node outside the master list must be rejected.
    missing_node_pattern = (
        f"These nodes are not in the master list: {immutableset(['3'])}"
    )
    with self.assertRaisesRegex(RuntimeError, missing_node_pattern):
        Digraph(nodes=("1", "2"), edges=(("1", "2"), ("1", "3")))
def test_pickling(self):
    def round_trip(value):
        return pickle.loads(pickle.dumps(value))

    populated = immutablesetmultidict([(1, (2, 2, 3, 6)), (4, (5, 6))])
    self.assertEqual(round_trip(populated), populated)
    self.assertEqual(round_trip(immutablesetmultidict()), immutablesetmultidict())

    # __reduce__ rebuilds through the factory from a tuple of (key, value) pairs.
    self.assertEqual(
        populated.__reduce__(),
        (immutablesetmultidict, (((1, (2, 2, 3, 6)), (4, (5, 6))),)),
    )
    self.assertEqual(
        immutablesetmultidict().__reduce__(), (immutablesetmultidict, ((),))
    )
class Ontology:
    r"""
    A hierarchical collection of types for objects, actions, etc.

    Types are represented by `OntologyNode`\ s with parent-child relationships.

    Every `OntologyNode` may have a set of properties which are inherited by all child nodes.

    Every `Ontology` must contain the special nodes `THING`, `RELATION`, `ACTION`,
    `PROPERTY`, `META_PROPERTY`, and `CAN_FILL_TEMPLATE_SLOT`.

    An `Ontology` must have an `ObjectStructuralSchema` associated with each
    `CAN_FILL_TEMPLATE_SLOT` `THING` which is not a substance (`IS_SUBSTANCE`).
    `ObjectStructuralSchema`\ ta are inherited, but any which are explicitly-specified
    will cause any inherited schemata to be ignored; when several ancestors specify
    schemata, the nearest one is used.

    To assist in creating legal `Ontology`\ s, we provide `minimal_ontology_graph`.
    """

    _name: str = attrib(validator=instance_of(str))
    _graph: DiGraph = attrib(validator=instance_of(DiGraph), converter=_copy_digraph)
    _structural_schemata: ImmutableSetMultiDict[
        "OntologyNode", "ObjectStructuralSchema"
    ] = attrib(converter=_to_immutablesetmultidict, default=immutablesetmultidict())
    action_to_description: ImmutableSetMultiDict[
        OntologyNode, ActionDescription
    ] = attrib(
        converter=_to_immutablesetmultidict,
        default=immutablesetmultidict(),
        kw_only=True,
    )
    relations: ImmutableSet[Relation[OntologyNode]] = attrib(
        converter=_to_immutableset, default=immutableset(), kw_only=True
    )
    subjects_to_relations: ImmutableSetMultiDict[
        OntologyNode, Relation[OntologyNode]
    ] = attrib(init=False)

    def __attrs_post_init__(self) -> None:
        for cycle in simple_cycles(self._graph):
            raise ValueError(f"The ontology graph may not have cycles but got {cycle}")

        for required_node in REQUIRED_ONTOLOGY_NODES:
            check_arg(
                required_node in self,
                f"Ontology lacks required {required_node.handle} node",
            )

        # every sub-type of THING must either have a structural schema
        # or be a sub-type of something with a structural schema
        for thing_node in dfs_preorder_nodes(self._graph.reverse(copy=False), THING):
            if not any(
                node in self._structural_schemata for node in self.ancestors(thing_node)
            ):
                # e.g. "milk" does not have a structural schema
                if self.has_all_properties(
                    thing_node, [CAN_FILL_TEMPLATE_SLOT]
                ) and not self.has_all_properties(thing_node, [IS_SUBSTANCE]):
                    raise RuntimeError(
                        f"No structural schema is available for {thing_node}"
                    )

    def ancestors(self, node: OntologyNode) -> Iterable[OntologyNode]:
        """
        Get *node* followed by all of its super-types, nearest first.

        Graph edges run from sub-types to super-types, so super-types are reached by
        following successors. The walk is breadth-first, so no ancestor is listed
        after one which is more hops away. (networkx's ``ancestors`` returns an
        unordered set, which made "nearest ancestor" lookups such as
        `structural_schemata` depend on set iteration order.)
        """
        ordered: List[OntologyNode] = [node]
        seen = {node}
        next_index = 0
        while next_index < len(ordered):
            current = ordered[next_index]
            next_index += 1
            for parent in self._graph.successors(current):
                if parent not in seen:
                    seen.add(parent)
                    ordered.append(parent)
        return ordered

    def structural_schemata(
        self, node: OntologyNode
    ) -> AbstractSet[ObjectStructuralSchema]:
        """
        Get the structural schemata applying to *node*.

        These are the explicitly-specified schemata of the nearest of *node* and its
        ancestors which has any, or an empty set if none does.
        """
        for candidate in self.ancestors(node):
            if candidate in self._structural_schemata:
                return self._structural_schemata[candidate]
        return immutableset()

    def is_subtype_of(
        self, node: "OntologyNode", query_supertype: "OntologyNode"
    ) -> bool:
        """
        Determines whether *node* is a sub-type of *query_supertype*.
        """
        # graph edges run from sub-types to super-types
        return has_path(self._graph, node, query_supertype)

    def nodes_with_properties(
        self,
        root_node: "OntologyNode",
        required_properties: Iterable["OntologyNode"] = immutableset(),
        *,
        banned_properties: AbstractSet["OntologyNode"] = immutableset(),
        banned_ontology_types: AbstractSet["OntologyNode"] = immutableset(),
    ) -> ImmutableSet["OntologyNode"]:
        r"""
        Get all `OntologyNode`\ s which are dominated by *root_node* (or are *root_node*
        itself) which (a) possess all the *required_properties*,
        (b) possess none of the *banned_properties*,
        either directly or by inheritance from a dominating node, and (c) are not
        identical to or descendants of any of *banned_ontology_types*.

        Args:
            root_node: the node to search the ontology tree at and under
            required_properties: the properties (as `OntologyNode`\ s) every returned
                node must have. May be any iterable, including a one-shot iterator.
            banned_properties: the properties (as `OntologyNode`\ s) which no returned
                node may have
            banned_ontology_types: nodes in the ontology which (together with their
                descendants) should never be returned.

        Returns:
             All `OntologyNode`\ s which are dominated by *root_node* (or are
             *root_node* itself) which possess all the *required_properties* and none
             of the *banned_properties*, either directly or by inheritance from a
             dominating node, and are not contained in or dominated by any of
             *banned_ontology_types*.

        Raises:
            RuntimeError: if *root_node* is not in this ontology.
        """
        if root_node not in self._graph:
            raise RuntimeError(
                f"Cannot get object with type {root_node} because it does not "
                f"appear in the ontology {self}"
            )
        # Materialize once: the properties are re-checked for every candidate node,
        # and a one-shot iterator would be exhausted after the first check.
        required_properties = immutableset(required_properties)
        return immutableset(
            node
            for node in dfs_preorder_nodes(self._graph.reverse(copy=False), root_node)
            if self.has_all_properties(
                node, required_properties, banned_properties=banned_properties
            )
            and node not in banned_ontology_types
            and not self.descends_from(node, banned_ontology_types)
        )

    def descends_from(
        self,
        node: "OntologyNode",
        query_ancestors: Union["OntologyNode", AbstractSet["OntologyNode"]],
    ) -> bool:
        """
        Whether some strict super-type of *node* is (one of) *query_ancestors*.

        *node* itself is never counted as its own ancestor.
        """
        if isinstance(query_ancestors, OntologyNode):
            query_ancestors_set = {query_ancestors}
        else:
            query_ancestors_set = query_ancestors  # type: ignore

        nodes_to_check: List["OntologyNode"] = []
        visited_nodes = {node}
        nodes_to_check.extend(self._graph.successors(node))

        while nodes_to_check:
            node_to_check = nodes_to_check.pop()
            visited_nodes.add(node_to_check)
            if node_to_check in query_ancestors_set:
                return True
            nodes_to_check.extend(
                parent
                for parent in self._graph.successors(node_to_check)
                if parent not in visited_nodes
            )
        return False

    def has_property(
        self, node: "OntologyNode", query_property: "OntologyNode"
    ) -> bool:
        r"""
        Checks whether an `OntologyNode` has a given property either directly
        or via inheritance.

        Args:
            node: the `OntologyNode` being inquired about
            query_property: the property being inquired about.

        Returns:
            Whether *node* possesses *query_property*, either directly or via
            inheritance from a dominating node.
        """
        return self.has_all_properties(node, (query_property,))

    def has_all_properties(
        self,
        node: "OntologyNode",
        query_properties: Iterable["OntologyNode"],
        *,
        banned_properties: AbstractSet["OntologyNode"] = immutableset(),
    ) -> bool:
        r"""
        Checks if an `OntologyNode` has the given properties, either directly or by inheritance.

        Args:
            node: the `OntologyNode` being inquired about
            query_properties: the properties being inquired about
            banned_properties: this function will return false if *node* contains
                any of these properties. Defaults to the empty set.

        Returns:
            Whether *node* possesses all of *query_properties* and none of
            *banned_properties*, either directly or via inheritance from a
            dominating node.
        """
        if not query_properties and not banned_properties:
            return True

        node_properties = self.properties_for_node(node)
        return all(
            property_ in node_properties for property_ in query_properties
        ) and not any(property_ in banned_properties for property_ in node_properties)

    def properties_for_node(self, node: "OntologyNode") -> ImmutableSet["OntologyNode"]:
        r"""
        Get all properties a `OntologyNode` possesses.

        Args:
            node: the `OntologyNode` whose properties you want.

        Returns:
            The non-inheritable properties of *node* itself plus the inheritable
            properties of *node* and every node dominating it.

        Raises:
            RuntimeError: if a node on the way up has more than one parent.
        """
        node_properties: List[OntologyNode] = list(node.non_inheritable_properties)
        cur_node = node
        while cur_node:
            node_properties.extend(cur_node.inheritable_properties)
            # need to make a tuple because we can't len() the returned iterator
            parents = tuple(self._graph.successors(cur_node))
            if len(parents) == 1:
                cur_node = parents[0]
            elif parents:
                raise RuntimeError(
                    f"Found multiple parents for ontology node {node}, which is "
                    f"not yet supported"
                )
            else:
                # we have reached a root
                break
        return immutableset(node_properties)

    def required_action_description(
        self, action_type: OntologyNode, semantic_roles: Iterable[OntologyNode]
    ) -> ActionDescription:
        """
        Get the unique description of *action_type* whose frame has exactly *semantic_roles*.

        Raises:
            RuntimeError: if no description, or more than one, matches.
        """
        semantic_roles_set = immutableset(semantic_roles)
        descriptions_for_action_type = self.action_to_description[action_type]
        matching_descriptions = immutableset(
            description
            for description in descriptions_for_action_type
            if description.frame.semantic_roles == semantic_roles_set
        )

        if matching_descriptions:
            if len(matching_descriptions) == 1:
                return only(matching_descriptions)
            else:
                raise RuntimeError(
                    f"Multiple action descriptions match action type "
                    f"{action_type} and roles {semantic_roles_set}"
                )
        else:
            available_frames: Any = [
                immutableset(description.frame.roles_to_variables.keys())
                for description in descriptions_for_action_type
            ]
            raise RuntimeError(
                f"No action descriptions match action type "
                f"{action_type} and roles {semantic_roles_set}. "
                f"Known frames for {action_type} are "
                f"{available_frames}"
            )

    def __contains__(self, item: "OntologyNode") -> bool:
        return item in self._graph.nodes

    def __repr__(self) -> str:
        return f"Ontology({self._name})"

    @subjects_to_relations.default
    def _subjects_to_relations(
        self,
    ) -> ImmutableSetMultiDict[OntologyNode, Relation[OntologyNode]]:
        # index each relation by the node filling its first slot
        return immutablesetmultidict(
            (relation.first_slot, relation) for relation in self.relations
        )
def _subjects_to_relations(
    self,
) -> ImmutableSetMultiDict[OntologyNode, Relation[OntologyNode]]:
    """Index every relation in `self.relations` by the node filling its first slot."""
    subject_relation_pairs = [
        (relation.first_slot, relation) for relation in self.relations
    ]
    return immutablesetmultidict(subject_relation_pairs)
def _init_language_span_to_node(
    self,
) -> ImmutableSetMultiDict[Span, ObjectSemanticNode]:
    """Invert `node_to_language_span`, indexing each span by the node(s) aligned to it."""
    span_node_pairs = [
        (span, semantic_node)
        for semantic_node, span in self.node_to_language_span.items()
    ]
    return immutablesetmultidict(span_node_pairs)
class DuringAction(Generic[_ObjectT]):
    """
    What happens while an action is in progress.

    Holds spatial paths per object (`objects_to_paths`), relations holding
    `at_some_point`, and relations holding `continuously`.
    """

    objects_to_paths: ImmutableSetMultiDict[
        _ObjectT, SpatialPath[_ObjectT]
    ] = attrib(
        converter=_to_immutablesetmultidict,
        default=immutablesetmultidict(),
        kw_only=True,
    )
    at_some_point: ImmutableSet[Relation[_ObjectT]] = attrib(
        converter=flatten_relations, default=immutableset(), kw_only=True
    )
    continuously: ImmutableSet[Relation[_ObjectT]] = attrib(
        converter=flatten_relations, default=immutableset(), kw_only=True
    )

    def copy_remapping_objects(
        self, object_mapping: Mapping[_ObjectT, _NewObjectT]
    ) -> "DuringAction[_NewObjectT]":
        """Copy this `DuringAction`, remapping every referenced object via *object_mapping*."""
        return DuringAction(
            objects_to_paths=(
                (object_mapping[object_], path.copy_remapping_objects(object_mapping))
                for (object_, path) in self.objects_to_paths.items()
            ),
            at_some_point=(
                relation.copy_remapping_objects(object_mapping)
                for relation in self.at_some_point
            ),
            continuously=(
                relation.copy_remapping_objects(object_mapping)
                for relation in self.continuously
            ),
        )

    def accumulate_referenced_objects(
        self, object_accumulator: List[_ObjectT]
    ) -> None:
        r"""
        Adds all objects referenced by this `DuringAction` to *object_accumulator*.
        """
        for (_, path) in self.objects_to_paths.items():
            path.accumulate_referenced_objects(object_accumulator)
        for relation in self.at_some_point:
            relation.accumulate_referenced_objects(object_accumulator)
        for relation in self.continuously:
            relation.accumulate_referenced_objects(object_accumulator)

    def union(
        self, other_during: "DuringAction[_ObjectT]"
    ) -> "DuringAction[_ObjectT]":
        """
        Unify two DuringAction together.

        For unifying spatial paths, the paths from `self` override any conflicts in
        `other_during`: the paths of both actions are pooled per object (those of
        `self` first), and every later path of an object sharing a
        `reference_source_object` with an earlier one is unified into the earlier one
        with ``override=True``. `at_some_point` and `continuously` relations of both
        actions are concatenated.
        """
        objects_to_paths = immutablesetmultidict(
            chain(self.objects_to_paths.items(), other_during.objects_to_paths.items())
        )

        objects_to_unified_paths: List[Tuple[_ObjectT, SpatialPath[_ObjectT]]] = []
        for obj in objects_to_paths.keys():
            paths = objects_to_paths[obj]
            # Skips are tracked per object: the same path value may legitimately be
            # attached to several objects and must be kept for each of them.
            paths_to_skip: Set[SpatialPath[_ObjectT]] = set()
            for (num, path) in enumerate(paths):
                if path in paths_to_skip:
                    continue
                unified_path = path
                # Start after `num`; unifying a path with itself is pointless.
                for i in range(num + 1, len(paths)):
                    later_path = paths[i]
                    if path.reference_source_object == later_path.reference_source_object:
                        # NOTE(review): assumes `SpatialPath.unify` returns the merged
                        # path (immutable value) rather than mutating in place -- confirm.
                        unified_path = unified_path.unify(later_path, override=True)
                        paths_to_skip.add(later_path)
                objects_to_unified_paths.append((obj, unified_path))

        return DuringAction(
            objects_to_paths=immutablesetmultidict(objects_to_unified_paths),
            at_some_point=chain(self.at_some_point, other_during.at_some_point),
            continuously=chain(self.continuously, other_during.continuously),
        )
def _init_entities_to_roles(
    self,
) -> ImmutableSetMultiDict[ActionDescriptionVariable, OntologyNode]:
    """Invert `roles_to_variables`, mapping each variable to the role(s) it fills."""
    variable_role_pairs = [
        (variable, role) for role, variable in self.roles_to_variables.items()
    ]
    return immutablesetmultidict(variable_role_pairs)
def test_empty(self):
    # Both construction routes must yield an empty, mutually-equal multidict.
    via_factory = immutablesetmultidict()
    self.assertEqual(0, len(via_factory))
    via_of = ImmutableSetMultiDict.of({})
    self.assertEqual(0, len(via_of))
    self.assertEqual(via_factory, via_of)
def test_empty_singleton(self):
    # Every way of building an empty multidict should hand back one shared instance.
    canonical_empty = immutablesetmultidict()
    self.assertIs(canonical_empty, immutablesetmultidict())
    self.assertIs(canonical_empty, ImmutableSetMultiDict.of({}))
def _testing_schemata(
    nodes: Iterable[OntologyNode],
) -> ImmutableSetMultiDict[OntologyNode, ObjectStructuralSchema]:
    """For tests: give each of *nodes* one `ObjectStructuralSchema` using `WORLD_AXES`."""
    node_schema_pairs = []
    for node in nodes:
        node_schema_pairs.append((node, ObjectStructuralSchema(node, axes=WORLD_AXES)))
    return immutablesetmultidict(node_schema_pairs)
def test_of(self):
    # Duplicate values collapse whether given as a mapping of lists...
    from_mapping = ImmutableSetMultiDict.of({1: [2, 2, 3], 4: [5, 6]})
    self.assertEqual(ImmutableSet.of([2, 3]), from_mapping[1])
    # ...or as repeated (key, value) pairs.
    from_pairs = immutablesetmultidict([(1, 2), (1, 2), (1, 3), (4, 5), (4, 6)])
    self.assertEqual(immutableset([2, 3]), from_pairs[1])
def test_hash(self):
    # Hashing must succeed, and multidicts which compare equal must hash equally
    # (previously this only checked that hash() does not raise).
    from_mapping = immutablesetmultidict({1: [2, 2, 3], 4: [5, 6]})
    from_pairs = immutablesetmultidict([(1, 2), (1, 3), (4, 5), (4, 6)])
    self.assertEqual(from_mapping, from_pairs)
    self.assertEqual(hash(from_mapping), hash(from_pairs))
def _init_objects_to_attributes(
    self,
) -> ImmutableSetMultiDict[ObjectSemanticNode, AttributeSemanticNode]:
    """Index every attribute node by the object filling its (single) slot."""
    object_attribute_pairs = []
    for attribute in self.attributes:
        # presumably `more_itertools.one`, which rejects anything but exactly one filler
        sole_filler = one(attribute.slot_fillings.values())
        object_attribute_pairs.append((sole_filler, attribute))
    return immutablesetmultidict(object_attribute_pairs)
def test_allowed_matches_with_bad_partial_match():
    """
    Tests whether PatternMatching's allowed_matches functionality works as intended
    when a bad partial match is specified.

    ``allowed_matches`` only lets the first pattern's box match the second pattern's
    box, so seeding the search with the box paired to the ground is expected to
    raise a `RuntimeError`.
    """
    obj_template = Phase1SituationTemplate(
        "colored-obj-object",
        salient_object_variables=[object_variable("obj-with-color", BOX)],
    )
    train_curriculum = phase1_instances(
        "all obj situations",
        situations=all_possible(
            obj_template,
            chooser=PHASE1_CHOOSER_FACTORY(),
            ontology=GAILA_PHASE_1_ONTOLOGY,
        ),
    )
    perceptual_representation = only(train_curriculum.instances())[2]
    perception = graph_without_learner(
        PerceptionGraph.from_frame(perceptual_representation.frames[0])
    )

    def _pattern_of_handles(handles):
        # pattern over the perception nodes whose debug handle is in *handles*
        return PerceptionGraphPattern.from_graph(
            perception.subgraph_by_nodes(
                {
                    cast(PerceptionGraphNode, node)
                    for node in perception._graph.nodes  # pylint: disable=protected-access
                    if getattr(node, "debug_handle", None) in handles
                }
            )
        ).perception_graph_pattern

    def _sole_node_with_handle(pattern, handle):
        # the single node of *pattern* carrying *handle*
        return cast(
            AnyObjectPerception,
            only(
                node
                for node in pattern._graph  # pylint: disable=protected-access
                if getattr(node, "debug_handle", None) == handle
            ),
        )

    pattern1: PerceptionGraphPattern = _pattern_of_handles({"box_0"})
    pattern2: PerceptionGraphPattern = _pattern_of_handles({"box_0", "the ground"})
    pattern1_box: AnyObjectPerception = _sole_node_with_handle(pattern1, "box_0")
    pattern2_box: AnyObjectPerception = _sole_node_with_handle(pattern2, "box_0")
    pattern2_ground: AnyObjectPerception = _sole_node_with_handle(pattern2, "the ground")

    matcher = PatternMatching(
        pattern=pattern1,
        graph_to_match_against=pattern2,
        matching_pattern_against_pattern=True,
        match_mode=MatchMode.OBJECT,
        allowed_matches=immutablesetmultidict([(pattern1_box, pattern2_box)]),
    )
    with pytest.raises(RuntimeError):
        first(
            matcher.matches(
                initial_partial_match={pattern1_box: pattern2_ground},
                use_lookahead_pruning=True,
            ),
            None,
        )