def _get_all_entity_trees_of_cls_helper(
        tree: EntityTree,
        cls: Type[DatabaseEntity],
        seen_ids: Set[int],
        seen_trees: List[EntityTree],
        direction_checker: SchemaEdgeDirectionChecker) -> None:
    """
    Finds all objects in the provided |tree| graph which have the type |cls|.
    When an object of type |cls| is found, updates the provided |seen_ids| and
    |seen_trees| with the object's id and EntityTree respectively.

    Only FORWARD_EDGE fields are followed, and the search never descends below
    an entity of type |cls|. Entities are de-duplicated by Python object
    identity (id()), so an entity reachable along several paths is recorded
    once, with the EntityTree of the first path that reached it.
    """
    entity = tree.entity
    entity_cls = entity.__class__

    # If |cls| is higher ranked than |entity_cls|, it is impossible to reach
    # an object of type |cls| from the current entity.
    if direction_checker.is_higher_ranked(cls, entity_cls):
        return

    if entity_cls == cls:
        # Record the first path to this object only. In every case stop here:
        # a first visit does not search below a |cls| entity, so a repeat visit
        # (same object reached via another parent) must not either -- doing so
        # was redundant work and made results depend on path multiplicity.
        if id(entity) not in seen_ids:
            seen_ids.add(id(entity))
            seen_trees.append(tree)
        return

    for child_field_name in get_set_entity_field_names(
            entity, EntityFieldType.FORWARD_EDGE):
        child_trees = tree.generate_child_trees(
            entity.get_field_as_list(child_field_name))
        for child_tree in child_trees:
            _get_all_entity_trees_of_cls_helper(
                child_tree, cls, seen_ids, seen_trees, direction_checker)
def test_schemaEdgeDirectionChecker_isHigherRanked_sameRank(self):
    """A class is never reported as higher ranked than itself."""
    checker = SchemaEdgeDirectionChecker.state_direction_checker()
    for same_cls in (StatePerson, StateSupervisionViolation):
        self.assertFalse(checker.is_higher_ranked(same_cls, same_cls))
def test_schemaEdgeDirectionChecker_isHigherRanked_higherRank(self):
    """Parent classes are reported as higher ranked than their descendants."""
    checker = SchemaEdgeDirectionChecker.state_direction_checker()
    ranked_pairs = [
        (StatePerson, StateSentenceGroup),
        (StateSentenceGroup, StateSupervisionViolation),
    ]
    for higher_cls, lower_cls in ranked_pairs:
        self.assertTrue(checker.is_higher_ranked(higher_cls, lower_cls))
def get_all_entity_trees_of_cls(sources: Sequence[DatabaseEntity],
                                cls: Type[DatabaseEntity]) -> List[EntityTree]:
    """Returns an EntityTree for each unique entity of type |cls| found in |sources|.

    Each source is treated as the root of its own EntityTree (with an empty
    ancestor chain) and searched in order. The visited-id set and result list
    are shared across all sources, so uniqueness holds over the whole input,
    not just within one source.
    """
    checker = SchemaEdgeDirectionChecker.state_direction_checker()
    visited_ids: Set[int] = set()
    found_trees: List[EntityTree] = []

    for root in sources:
        root_tree = EntityTree(entity=root, ancestor_chain=[])
        _get_all_entity_trees_of_cls_helper(
            root_tree, cls, visited_ids, found_trees, checker)

    return found_trees
def get_multiparent_classes() -> List[Type[DatabaseEntity]]:
    """Returns the hard-coded list of state schema classes deemed multi-parent.

    Going by the function name, these are presumably classes whose instances
    can hang off more than one parent -- confirm with callers. Before being
    returned, the list is passed to the state direction checker's
    assert_sorted(), so its order must satisfy that check.
    """
    multiparent_classes: List[Type[DatabaseEntity]] = [
        schema.StateCharge,
        schema.StateCourtCase,
        schema.StateIncarcerationPeriod,
        schema.StateParoleDecision,
        schema.StateSupervisionPeriod,
        schema.StateSupervisionViolation,
        schema.StateSupervisionViolationResponse,
        schema.StateProgramAssignment,
        schema.StateAgent,
    ]
    SchemaEdgeDirectionChecker.state_direction_checker().assert_sorted(
        multiparent_classes)
    return multiparent_classes
def _get_related_entities(self,
                          entity: DatabaseEntity) -> List[DatabaseEntity]:
    """Returns list of all entities related to |entity|.

    Walks every relationship property of |entity|. A list-valued relationship
    contributes all of its items; a single-valued relationship contributes its
    value unless that value is None.

    When this object's system level is SystemLevel.STATE, relationships that
    the state SchemaEdgeDirectionChecker reports as back edges are skipped.
    """
    # TODO(#1145): For County schema, fix direction checker to gracefully
    # handle the fact that SentenceRelationship exists in the schema
    # but not in the entity layer.
    # Neither the system level nor the checker depends on the relationship
    # being inspected, so resolve them once rather than on every iteration.
    direction_checker = (
        SchemaEdgeDirectionChecker.state_direction_checker()
        if self.get_system_level() == SystemLevel.STATE else None)

    related_entities: List[DatabaseEntity] = []
    for relationship_name in entity.get_relationship_property_names():
        # Skip back edges
        if direction_checker is not None and direction_checker.is_back_edge(
                entity, relationship_name):
            continue

        related = getattr(entity, relationship_name)
        # Relationship can return either a list or a single item
        if isinstance(related, list):
            related_entities.extend(related)
        elif related is not None:
            related_entities.append(related)
    return related_entities
def __init__(self):
    """Initializes the base class with a checker built from CLASS_HIERARCHY.

    The checker is constructed from this class's CLASS_HIERARCHY attribute and
    the module-level name `entities` (presumably the entities module matching
    that hierarchy -- confirm).
    """
    super().__init__(
        SchemaEdgeDirectionChecker(self.CLASS_HIERARCHY, entities))
def __init__(self):
    """Initializes the base class with the county SchemaEdgeDirectionChecker."""
    county_checker = SchemaEdgeDirectionChecker.county_direction_checker()
    super().__init__(county_checker)
def expand(self, input_or_inputs):
    """Applies one extraction transform per forward-edge relationship property.

    For each relationship property of |self._parent_schema_class|, this
    resolves the property's entity class and state schema class. A property
    is extracted from |input_or_inputs| only when the parent class is ranked
    above the property's schema class, per the state
    SchemaEdgeDirectionChecker. Back edges are skipped.

    The transform depends on the relationship shape:
      * many-to-many (the property has a `secondary` table):
        _ExtractEntityWithAssociationTable over the secondary table, keyed by
        the property entity's class id name;
      * 1-to-many (`uselist` is truthy): _ExtractEntity;
      * otherwise, 1-to-1 from the parent's perspective:
        _ExtractEntityWithAssociationTable, with the parent's own table as the
        association table and '<property key>_id' as the entity id column.

    Returns:
        A dict mapping each forward-edge property name to the result of
        applying its transform to |input_or_inputs| (presumably a Beam
        PCollection -- confirm).
    """
    names_to_properties = (
        self._parent_schema_class.get_relationship_property_names_and_properties())

    # The checker does not depend on the property being inspected, so build
    # it once instead of on every loop iteration.
    direction_checker = SchemaEdgeDirectionChecker.state_direction_checker()

    properties_dict = {}
    for property_name, property_object in names_to_properties.items():
        # Get class name associated with the property
        property_class_name = property_object.argument.arg

        property_entity_class = entity_utils.get_entity_class_in_module_with_name(
            state_entities, property_class_name)
        property_schema_class = schema_utils.get_state_database_entity_with_name(
            property_class_name)

        is_property_forward_edge = direction_checker.is_higher_ranked(
            self._parent_schema_class, property_schema_class)
        if not is_property_forward_edge:
            continue

        if property_object.secondary is not None:
            # Many-to-many relationship
            uses_association_table = True
            association_table = property_object.secondary.name
            association_table_entity_id_field = \
                property_entity_class.get_class_id_name()
        elif property_object.uselist:
            # 1-to-many relationship
            uses_association_table = False
        else:
            # 1-to-1 relationship (from parent class perspective)
            uses_association_table = True
            association_table = self._parent_schema_class.__tablename__
            association_table_entity_id_field = property_object.key + '_id'

        # Named `extracted` rather than `entities` to avoid shadowing the
        # module-level `entities` name used elsewhere in this file.
        label = f"Extract {property_name}"
        if uses_association_table:
            extracted = (
                input_or_inputs
                | label >> _ExtractEntityWithAssociationTable(
                    dataset=self._dataset,
                    entity_class=property_entity_class,
                    unifying_id_field=self._unifying_id_field,
                    parent_id_field=self._parent_id_field,
                    association_table=association_table,
                    association_table_parent_id_field=self._parent_id_field,
                    association_table_entity_id_field=
                    association_table_entity_id_field,
                    unifying_id_field_filter_set=
                    self._unifying_id_field_filter_set,
                    state_code=self._state_code))
        else:
            extracted = (
                input_or_inputs
                | label >> _ExtractEntity(
                    dataset=self._dataset,
                    entity_class=property_entity_class,
                    unifying_id_field=self._unifying_id_field,
                    parent_id_field=self._parent_id_field,
                    unifying_id_field_filter_set=
                    self._unifying_id_field_filter_set,
                    state_code=self._state_code))

        properties_dict[property_name] = extracted

    return properties_dict
def __init__(self) -> None:
    """Initializes the base class with the state SchemaEdgeDirectionChecker."""
    state_checker = SchemaEdgeDirectionChecker.state_direction_checker()
    super().__init__(state_checker)