def test_list_embedded_reference_dereference(self):
    # Test dereferencing items stored in a
    # ListField(EmbeddedModel(ReferenceField(X)))
    class OtherModel(MongoModel):
        name = fields.CharField()

    class OtherRefModel(EmbeddedMongoModel):
        ref = fields.ReferenceField(OtherModel)

    class Container(MongoModel):
        lst = fields.EmbeddedModelListField(OtherRefModel)

    m1 = OtherModel('Aaron').save()
    m2 = OtherModel('Bob').save()

    container = Container(lst=[OtherRefModel(ref=m1), OtherRefModel(ref=m2)])
    container.save()

    # Force ObjectIds.
    container.refresh_from_db()

    dereference(container)
    # Disable dereferencing and check for dereferenced values.
    with no_auto_dereference(container):
        self.assertEqual(container.lst[0].ref.name, 'Aaron')
    # Ensure dereferenced values during normal access.
    self.assertEqual(container.lst[0].ref.name, 'Aaron')
def dereference(model_instance, fields=None):
    """Dereference ReferenceFields on a MongoModel instance.

    This function is handy for dereferencing many fields at once and is more
    efficient than dereferencing one field at a time.

    :parameters:
      - `model_instance`: The MongoModel instance.
      - `fields`: An iterable of field names in "dot" notation that should be
        dereferenced. If left blank, all fields will be dereferenced.
    """
    # Map of collection name --> list of ids to retrieve from the collection.
    reference_map = defaultdict(list)

    # Fields may be nested (dot-notation). Split each field into its parts.
    if fields:
        fields = [deque(field.split('.')) for field in fields]

    # Tell ReferenceFields not to look up their value while we scan the object.
    with no_auto_dereference(model_instance):
        _find_references(model_instance, reference_map, fields)

        db = _get_db(model_instance._mongometa.connection_alias)
        # Resolve all references, one collection at a time.
        # This will give us a mapping of
        # {collection_name --> {id --> resolved object}}
        document_map = _resolve_references(db, reference_map)

        # Traverse the object and attach resolved references where needed.
        _attach_objects(model_instance, document_map, fields)

    return model_instance
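# A minimal usage sketch (not from the source) for dereference() above. The
# Author/Book models, connection URI, and database name are illustrative
# assumptions; only the pymodm calls themselves are real API.
from pymodm import MongoModel, connect, fields
from pymodm.context_managers import no_auto_dereference
from pymodm.dereference import dereference

connect('mongodb://localhost:27017/dereference_example')  # assumed local MongoDB


class Author(MongoModel):
    name = fields.CharField()


class Book(MongoModel):
    title = fields.CharField()
    author = fields.ReferenceField(Author)


author = Author(name='Ursula K. Le Guin').save()
book = Book(title='The Dispossessed', author=author).save()
book.refresh_from_db()                 # The raw stored value of 'author' is an ObjectId.

dereference(book, fields=['author'])   # Resolve only the 'author' field, in bulk.
with no_auto_dereference(book):
    # The reference stays resolved even with lazy dereferencing turned off.
    assert isinstance(book.author, Author)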
def test_select_related(self):
    class Comment(MongoModel):
        body = fields.CharField()

    class Post(MongoModel):
        body = fields.CharField()
        comments = fields.ListField(fields.ReferenceField(Comment))

    # Create a few objects...
    Post(body='Nobody read this post').save()
    comments = [
        Comment(body='This is a great post').save(),
        Comment(body='Horrible read').save()
    ]
    Post(body='More popular post', comments=comments).save()

    with no_auto_dereference(Post):
        posts = list(Post.objects.all())
        self.assertIsNone(posts[0].comments)
        self.assertIsInstance(posts[1].comments[0], ObjectId)
        self.assertIsInstance(posts[1].comments[1], ObjectId)

    posts = list(Post.objects.select_related())
    self.assertIsNone(posts[0].comments)
    self.assertEqual(posts[1].comments, comments)
def test_list_dereference(self):
    # Test dereferencing items stored in a ListField(ReferenceField(X))
    class OtherModel(MongoModel):
        name = fields.CharField()

    class Container(MongoModel):
        one_to_many = fields.ListField(fields.ReferenceField(OtherModel))

    m1 = OtherModel('a').save()
    m2 = OtherModel('b').save()
    container = Container([m1, m2]).save()

    # Implicit dereferencing.
    container.refresh_from_db()
    self.assertEqual([m1, m2], container.one_to_many)

    # Force ObjectIds.
    container.refresh_from_db()
    with no_auto_dereference(container):
        for item in container.one_to_many:
            self.assertIsInstance(item, ObjectId)

    # Explicit dereferencing.
    dereference(container)
    self.assertEqual([m1, m2], container.one_to_many)
def test_dereference_reference_not_found(self):
    post = Post(title='title').save()
    comment = Comment(body='this is a comment', post=post).save()
    post.delete()
    self.assertEqual(Post.objects.count(), 0)

    comment.refresh_from_db()
    with no_auto_dereference(comment):
        # Post uses its title as its primary key in these tests, so the raw
        # (non-dereferenced) reference value is the title string.
        self.assertEqual(comment.post, 'title')
    dereference(comment)
    self.assertIsNone(comment.post)
def test_dereference_dereferenced_reference(self):
    class CommentContainer(MongoModel):
        ref = fields.ReferenceField(Comment)

    post = Post(title='title').save()
    comment = Comment(body='Comment Body', post=post).save()
    container = CommentContainer(ref=comment).save()

    with no_auto_dereference(comment), no_auto_dereference(container):
        comment.refresh_from_db()
        container.refresh_from_db()
        container.ref = comment
        self.assertEqual(container.ref.post, 'title')

        dereference(container)
        self.assertIsInstance(container.ref.post, Post)
        self.assertEqual(container.ref.post.title, 'title')

        # Dereferencing an already-dereferenced reference should be a no-op.
        dereference(container)
        self.assertIsInstance(container.ref.post, Post)
        self.assertEqual(container.ref.post.title, 'title')
def test_no_auto_dereference(self):
    game = Game('Civilization').save()
    badge = Badge(name='World Domination', game=game)
    ernie = User(fname='Ernie').save()
    bert = User(fname='Bert', badges=[badge], friend=ernie).save()

    bert.refresh_from_db()
    with no_auto_dereference(User):
        self.assertIsInstance(bert.friend, ObjectId)
        self.assertIsInstance(bert.badges[0].game, ObjectId)
    self.assertIsInstance(bert.friend, User)
    self.assertIsInstance(bert.badges[0].game, Game)
def test_leaf_field_dereference(self):
    # Test basic dereference of a ReferenceField directly in the Model.
    post = Post(title='This is a post.').save()
    comment = Comment(body='This is a comment on the post.', post=post).save()

    # Force ObjectIds on comment.
    comment.refresh_from_db()
    with no_auto_dereference(Comment):
        self.assertEqual(comment.post, post.title)

    dereference(comment)
    self.assertEqual(comment.post, post)
def full_clean(self, exclude=None):
    """Validate this :class:`~pymodm.MongoModel`.

    This method calls :meth:`~pymodm.MongoModel.clean_fields` to validate the
    values of all fields, and then :meth:`~pymodm.MongoModel.clean` to apply
    any custom validation rules to the model as a whole.

    :parameters:
      - `exclude`: A list of fields to exclude from validation.
    """
    with no_auto_dereference(self):
        self.clean_fields(exclude=exclude)
        self.clean()
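# A minimal sketch (assumed Booking model, not from the source) of the
# validation flow full_clean() drives: clean_fields() validates each field,
# then clean() applies a custom whole-model rule.
from pymodm import MongoModel, connect, fields
from pymodm.errors import ValidationError

connect('mongodb://localhost:27017/validation_example')  # assumed local MongoDB


class Booking(MongoModel):
    start = fields.DateTimeField()
    end = fields.DateTimeField()

    def clean(self):
        # Cross-field rule; full_clean() calls this after clean_fields().
        if self.start is not None and self.end is not None and self.end <= self.start:
            raise ValidationError('end must be after start')


# Booking(start=later, end=earlier).full_clean() would raise ValidationError.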
def test_circular_reference(self):
    class ReferenceA(MongoModel):
        ref = fields.ReferenceField('ReferenceB')

    class ReferenceB(MongoModel):
        ref = fields.ReferenceField(ReferenceA)

    a = ReferenceA().save()
    b = ReferenceB().save()
    a.ref = b
    b.ref = a
    a.save()
    b.save()

    self.assertEqual(a, ReferenceA.objects.first())
    with no_auto_dereference(ReferenceA):
        self.assertEqual(b.pk, ReferenceA.objects.first().ref)
    self.assertEqual(b, ReferenceA.objects.select_related().first().ref)
def to_son(self):
    """Get this Model back as a :class:`~bson.son.SON` object.

    :returns: SON representing this object as a MongoDB document.
    """
    son = SON()
    with no_auto_dereference(self):
        for field in self._mongometa.get_fields():
            if field.is_undefined(self):
                continue
            raw_value = self._data.get(field.attname)
            if field.is_blank(raw_value):
                son[field.mongo_name] = raw_value
            else:
                son[field.mongo_name] = field.to_mongo(raw_value)
    # Add metadata about our type, so that we instantiate the right class
    # when retrieving from MongoDB.
    if not self._mongometa.final:
        son['_cls'] = self._mongometa.object_name
    return son
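# A hedged sketch (assumed Shape/Circle models, not from the source) of what
# to_son() above produces: the field values plus a '_cls' entry, because the
# models are not declared final.
from pymodm import MongoModel, connect, fields

connect('mongodb://localhost:27017/to_son_example')  # assumed local MongoDB


class Shape(MongoModel):
    label = fields.CharField()


class Circle(Shape):
    radius = fields.FloatField()


son = Circle(label='c1', radius=2.0).to_son()
# Something like SON([('label', 'c1'), ('radius', 2.0), ('_cls', ...)]); the
# exact key order may vary, and the '_cls' value is
# Circle._mongometa.object_name, used to pick the right class when the
# document is loaded back.
print(son)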
def test_dereference_fields(self):
    # Test dereferencing only specific fields.

    # Contrived Models that contain more than one ReferenceField at
    # different levels of nesting.
    class MultiReferenceModelEmbed(EmbeddedMongoModel):
        comments = fields.ListField(fields.ReferenceField(Comment))
        posts = fields.ListField(fields.ReferenceField(Post))

    class MultiReferenceModel(MongoModel):
        comments = fields.ListField(fields.ReferenceField(Comment))
        posts = fields.ListField(fields.ReferenceField(Post))
        embeds = fields.EmbeddedDocumentListField(MultiReferenceModelEmbed)

    post = Post(title='This is a post.').save()
    comments = [
        Comment('comment 1', post).save(),
        Comment('comment 2').save()
    ]
    embed = MultiReferenceModelEmbed(comments=comments, posts=[post])
    multi_ref = MultiReferenceModel(
        comments=comments, posts=[post], embeds=[embed]).save()

    # Force ObjectIds.
    multi_ref.refresh_from_db()

    dereference(multi_ref, fields=['embeds.comments', 'posts'])

    post.refresh_from_db()
    for comment in comments:
        comment.refresh_from_db()

    with no_auto_dereference(MultiReferenceModel):
        self.assertEqual([post], multi_ref.posts)
        self.assertEqual(comments, multi_ref.embeds[0].comments)
        # multi_ref.comments has not been dereferenced.
        self.assertIsInstance(multi_ref.comments[0], ObjectId)
def _test_unhashable_id(self, final_value=True):
    # Test that we can reference a model whose id type is unhashable,
    # e.g. a dict, list, etc.
    class CardIdentity(EmbeddedMongoModel):
        HEARTS, DIAMONDS, SPADES, CLUBS = 0, 1, 2, 3

        rank = fields.IntegerField(min_value=0, max_value=12)
        suit = fields.IntegerField(choices=(HEARTS, DIAMONDS, SPADES, CLUBS))

        class Meta:
            final = final_value

    class Card(MongoModel):
        id = fields.EmbeddedModelField(CardIdentity, primary_key=True)
        flavor = fields.CharField()

    class Hand(MongoModel):
        cards = fields.ListField(fields.ReferenceField(Card))

    cards = [
        Card(CardIdentity(4, CardIdentity.CLUBS)).save(),
        Card(CardIdentity(12, CardIdentity.SPADES)).save()
    ]
    hand = Hand(cards).save()

    # Test auto dereferencing.
    hand.refresh_from_db()
    self.assertIsInstance(hand.cards[0], Card)
    self.assertEqual(hand.cards[0].id.rank, 4)
    self.assertIsInstance(hand.cards[1], Card)
    self.assertEqual(hand.cards[1].id.rank, 12)

    # Test explicit dereference().
    with no_auto_dereference(hand):
        hand.refresh_from_db()
        dereference(hand)
        self.assertIsInstance(hand.cards[0], Card)
        self.assertEqual(hand.cards[0].id.rank, 4)
        self.assertIsInstance(hand.cards[1], Card)
        self.assertEqual(hand.cards[1].id.rank, 12)
def test_dereference_models_with_same_id(self):
    class User(MongoModel):
        name = fields.CharField(primary_key=True)

    class CommentWithUser(MongoModel):
        body = fields.CharField()
        post = fields.ReferenceField(Post)
        user = fields.ReferenceField(User)

    post = Post(title='Bob').save()
    user = User(name='Bob').save()
    comment = CommentWithUser(
        body='this is a comment', post=post, user=user).save()

    comment.refresh_from_db()
    with no_auto_dereference(CommentWithUser):
        dereference(comment)
        self.assertIsInstance(comment.post, Post)
        self.assertIsInstance(comment.user, User)
def test_refresh_from_db(self):
    post = Post(body='This is a post.')
    comment = Comment(body='This is a comment on the post.', post=post)
    post.save()
    comment.save()

    comment.refresh_from_db()
    with no_auto_dereference(Comment):
        self.assertIsInstance(comment.post, ObjectId)

    # Use PyMongo to update the comment, then update the Comment instance's
    # view of itself.
    DB.comment.update_one(
        {'_id': comment.pk}, {'$set': {'body': 'Edited comment.'}})

    # Set the comment's "post" to something else.
    other_post = Post(body='This is a different post.')
    comment.post = other_post

    comment.refresh_from_db(fields=['body'])
    self.assertEqual('Edited comment.', comment.body)
    # "post" field is gone, since it wasn't part of the projection.
    self.assertIsNone(comment.post)
def test_schedule_tasks_uses_measure_trial_tasks_correctly(
        self, mock_task_manager, mock_load_minimal_trial):
    systems = [mock_core.MockSystem(_id=ObjectId()) for _ in range(1)]
    image_sources = [mock_core.MockImageSource(_id=ObjectId()) for _ in range(1)]
    metrics = [mock_core.MockMetric(_id=ObjectId()) for _ in range(1)]
    repeats = 1
    trial_results_map = make_trial_map(systems, image_sources, repeats)
    make_mock_get_run_system_task(mock_task_manager, trial_results_map)
    patch_load_minimal_trial(mock_load_minimal_trial)

    metric_result_ids = []

    def mock_get_measure_trial_task(*_, **__):
        result_id = ObjectId()
        metric_result_ids.append(result_id)
        mock_task = mock.create_autospec(MeasureTrialTask)
        mock_task.is_finished = True
        mock_task.result_id = result_id
        return mock_task

    mock_task_manager.get_measure_trial_task.side_effect = mock_get_measure_trial_task

    subject = SimpleExperiment(
        name="TestSimpleExperiment",
        systems=systems,
        image_sources=image_sources,
        metrics=metrics,
        repeats=repeats,
    )
    subject.schedule_tasks()

    with no_auto_dereference(SimpleExperiment):
        for metric_result_id in metric_result_ids:
            self.assertIn(metric_result_id, subject.metric_results)
def test_embedded_reference_dereference(self):
    # Test dereferencing items stored in an
    # EmbeddedDocument(ReferenceField(X))
    class OtherModel(MongoModel):
        name = fields.CharField()

    class OtherRefModel(EmbeddedMongoModel):
        ref = fields.ReferenceField(OtherModel)

    class Container(MongoModel):
        emb = fields.EmbeddedDocumentField(OtherRefModel)

    m1 = OtherModel('Aaron').save()
    container = Container(emb=OtherRefModel(ref=m1))
    container.save()

    # Force ObjectIds.
    with no_auto_dereference(container):
        container.refresh_from_db()
        self.assertIsInstance(container.emb.ref, ObjectId)

        dereference(container)
        self.assertIsInstance(container.emb.ref, OtherModel)
        self.assertEqual(container.emb.ref.name, 'Aaron')
def test_embedded_reference_dereference(self):
    # Test dereferencing items stored in an
    # EmbeddedModel(ReferenceField(X))
    class OtherModel(MongoModel):
        name = fields.CharField()

    class OtherRefModel(EmbeddedMongoModel):
        ref = fields.ReferenceField(OtherModel)

    class Container(MongoModel):
        emb = fields.EmbeddedModelField(OtherRefModel)

    m1 = OtherModel('Aaron').save()
    container = Container(emb=OtherRefModel(ref=m1))
    container.save()

    # Force ObjectIds.
    with no_auto_dereference(container):
        container.refresh_from_db()
        self.assertIsInstance(container.emb.ref, ObjectId)

        dereference(container)
        self.assertIsInstance(container.emb.ref, OtherModel)
        self.assertEqual(container.emb.ref.name, 'Aaron')
def add_image_sources(self, image_sources: typing.Iterable[ImageSource]):
    """
    Add the given image sources to this experiment if they are not already associated with it
    :param image_sources:
    :return:
    """
    with no_auto_dereference(type(self)):
        if self.image_sources is None:
            existing_pks = set()
        else:
            existing_pks = {
                image_source.pk if hasattr(image_source, 'pk') else image_source
                for image_source in self.image_sources
            }
        new_image_sources = [
            image_source for image_source in image_sources
            if image_source.pk not in existing_pks
        ]
        if len(new_image_sources) > 0:
            if self.image_sources is None:
                self.image_sources = new_image_sources
            else:
                self.image_sources.extend(new_image_sources)
def add_vision_systems(self, vision_systems: typing.Iterable[VisionSystem]):
    """
    Add the given vision systems to this experiment if they are not already associated with it
    :param vision_systems:
    :return:
    """
    with no_auto_dereference(type(self)):
        if self.systems is None:
            existing_pks = set()
        else:
            existing_pks = {
                system.pk if hasattr(system, 'pk') else system
                for system in self.systems
            }
        new_vision_systems = [
            vision_system for vision_system in vision_systems
            if vision_system.pk not in existing_pks
        ]
        if len(new_vision_systems) > 0:
            if self.systems is None:
                self.systems = new_vision_systems
            else:
                self.systems.extend(new_vision_systems)
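# A hedged usage sketch for the two helpers above; the experiment instance and
# the saved system/image-source objects ('experiment', 'orb_slam_system',
# 'kitti_sequence_00') are assumed names, not from the source. Adding the same
# object twice should be a no-op because its pk is already recorded.
experiment.add_vision_systems([orb_slam_system])
experiment.add_image_sources([kitti_sequence_00])
experiment.add_vision_systems([orb_slam_system])   # duplicate, filtered out by pk
experiment.save()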
def load_referenced_models(self) -> None:
    """
    Load the metric, trial, and result types so we can save the task
    :return:
    """
    with no_auto_dereference(CompareTrialTask):
        if isinstance(self.metric, bson.ObjectId):
            # The metric is just an id; autoload its model type so it can be loaded.
            autoload_modules(TrialComparisonMetric, [self.metric])
        trials_to_load = {
            trial_id for trial_id in self.trial_results_1
            if isinstance(trial_id, bson.ObjectId)
        }
        trials_to_load |= {
            trial_id for trial_id in self.trial_results_2
            if isinstance(trial_id, bson.ObjectId)
        }
        if len(trials_to_load) > 0:
            autoload_modules(TrialResult, list(trials_to_load))
        if isinstance(self.result, bson.ObjectId):
            # The result is an id and not a model; autoload the model.
            autoload_modules(TrialComparisonResult, [self.result])
def CreateFile(**request_handler_args):
    req = request_handler_args['req']
    authUser(req, request_handler_args['resp'], ['createFile'])
    user = req.context['user']
    doc = req.context['doc']

    try:
        project = Project.objects.get(
            {"_id": ObjectId(request_handler_args['uri_fields']['id'])})
    except InvalidId as e:
        raise falcon.HTTPBadRequest('Bad Request', str(e))
    except Project.DoesNotExist:
        raise falcon.HTTPNotFound()
    else:
        with context_managers.no_auto_dereference(Project):
            if user._id in project.users:
                try:
                    path = ProjectFS.TranslatePath(
                        doc['path'], doc['platform'], project.paths)
                except (ValueError, KeyError) as e:
                    raise falcon.HTTPBadRequest('Bad Request', str(e))

                try:
                    if doc['type'] == 'file':
                        ProjectFS.CreateFile(path)
                    elif doc['type'] == 'folder':
                        ProjectFS.CreateDirectory(path)
                    else:
                        raise falcon.HTTPBadRequest(
                            'Bad Request',
                            'Type property of FS object must be either file or folder.')
                except (IOError, OSError) as e:
                    req.context['logger'].error(
                        {'action': 'createFile', 'message': str(e)})
                    raise falcon.HTTPInternalServerError(
                        'Internal Server Error', str(e))
                except KeyError:
                    raise falcon.HTTPBadRequest(
                        'Bad Request', 'FS Object is missing type property.')

                try:
                    request_handler_args['resp'].status = falcon.HTTP_201
                    req.context['result'] = {
                        'path': path,
                        'security': ACL.GetACL(path),
                        'created': ProjectFS.GetCTime(path),
                        'modified': ProjectFS.GetMTime(path),
                        'accessed': ProjectFS.GetATime(path)
                    }
                    req.context['logger'].info({
                        'action': 'createFile',
                        'message': "File/Folder '%s' was created for project %s(%s)"
                                   % (path, project.name, project._id)
                    })
                except ACL.error as e:
                    req.context['logger'].error(
                        {'action': 'createFile', 'message': str(e)})
                    raise falcon.HTTPInternalServerError(
                        'Internal Server Error', str(e))
            else:
                raise falcon.HTTPForbidden(
                    'Forbidden',
                    "%s is not assigned to this project." % user.username)
def test_dereference_missed_reference_field(self):
    comment = Comment(body='Body Comment').save()
    with no_auto_dereference(comment):
        comment.refresh_from_db()
        dereference(comment)
        self.assertIsNone(comment.post)
def SetACL(**request_handler_args):
    req = request_handler_args['req']
    authUser(req, request_handler_args['resp'], ['setACL'])
    user = req.context['user']
    doc = req.context['doc']

    try:
        project = Project.objects.get(
            {"_id": ObjectId(request_handler_args['uri_fields']['id'])})
    except InvalidId as e:
        raise falcon.HTTPBadRequest('Bad Request', str(e))
    except Project.DoesNotExist:
        raise falcon.HTTPNotFound()
    else:
        with context_managers.no_auto_dereference(Project):
            if user._id in project.users:
                try:
                    matched_acl = ProjectFS.GetMatchACLPath(
                        doc['path'], doc['platform'],
                        project.acl_expanded, project.acl_expanded_depth)
                    path = ProjectFS.TranslatePath(
                        doc['path'], doc['platform'], project.paths)
                except (KeyError, ValueError) as e:
                    raise falcon.HTTPBadRequest('Bad Request: Missing Key', str(e))
                else:
                    if matched_acl is not None:
                        try:
                            ACL.SetMatchedACL(path, matched_acl)
                        except ACL.error as e:
                            req.context['logger'].error(
                                {'action': 'setACL', 'message': str(e)})
                            raise falcon.HTTPInternalServerError(
                                'Internal Server Error', str(e))
                        try:
                            req.context['result'] = {
                                'path': path,
                                'security': ACL.GetACL(path),
                                'created': ProjectFS.GetCTime(path),
                                'modified': ProjectFS.GetMTime(path),
                                'accessed': ProjectFS.GetATime(path)
                            }
                        except ACL.error as e:
                            req.context['logger'].error(
                                {'action': 'setACL', 'message': str(e)})
                            raise falcon.HTTPInternalServerError(
                                'Internal Server Error', str(e))
                    else:
                        request_handler_args['resp'].status = falcon.HTTP_202
                        try:
                            req.context['result'] = {
                                'path': path,
                                'security': ACL.GetACL(path),
                                'created': ProjectFS.GetCTime(path),
                                'modified': ProjectFS.GetMTime(path),
                                'accessed': ProjectFS.GetATime(path)
                            }
                            req.context['logger'].info({
                                'action': 'setACL',
                                'message': "File/Folder '%s' was created for project %s(%s)"
                                           % (path, project.name, project._id)
                            })
                        except ACL.error as e:
                            req.context['logger'].error(
                                {'action': 'setACL', 'message': str(e)})
                            raise falcon.HTTPInternalServerError(
                                'Internal Server Error', str(e))
            else:
                raise falcon.HTTPForbidden(
                    'Forbidden',
                    "%s is not assigned to this project." % user.username)
def get_brute_schedule(user):
    with no_auto_dereference(Event):
        users_events = list(Event.objects.raw({'user': user.id}).all())

        # Group the user's events by category name.
        events = {}
        for event in users_events:
            key = event.category['name'] if event.category else 'default'
            event.duration = event.finish_date - event.start_date
            if key not in events:
                events[key] = [event]
            else:
                events[key].append(event)
        print('Made events')

        # Order the category keys by how many events each category contains.
        sorted_keys = []
        for _ in events:
            most_often_event_category = list(set(events.keys()) - set(sorted_keys))[0]
            for key in events:
                if (len(events[key]) > len(events.get(most_often_event_category))
                        and key not in sorted_keys):
                    most_often_event_category = key
            sorted_keys.append(most_often_event_category)
        print('Made sorted keys')

        for i in range(len(sorted_keys)):
            key = sorted_keys[i]
            sorted_keys[i] = {'key': key, 'size': len(events[key])}
        print('Updated sorted keys')

        # Turn the category sizes into a rough probability distribution.
        sizes = np.array([item['size'] for item in sorted_keys])
        mu = sizes.mean()
        sigma = sizes.var() / len(sizes)
        pdf = np.random.normal(mu, sizes.var() ** 0.5, len(sizes))
        probs = get_probs(pdf, mu, sigma)
        print('Made probs')

        # Allocate hours of the day to each category in proportion to its probability.
        hours_in_day = 15
        hours = [int(hours_in_day * prob) or 1 for prob in probs]
        if sum(hours) > hours_in_day:
            delta = sum(hours) - hours_in_day
            for i in range(len(hours)):
                if hours[i] == max(hours):
                    hours[i] -= delta
                    break

        hours_day_delta = timedelta(hours=hours_in_day)
        final_schedule = []
        print('Before final schedule loop')
        for hour, item in zip(hours, sorted_keys):
            key_events = events[item['key']]
            rest_hours_for_key = timedelta(hours=hour)
            hours_day_delta -= rest_hours_for_key
            visited = []
            while (rest_hours_for_key.total_seconds() / 60 > 10
                   and len(visited) <= len(key_events)):
                event = find_event_shorter_than_delta(
                    key_events, rest_hours_for_key, visited)
                if event is None:
                    hours_day_delta += rest_hours_for_key
                    break
                else:
                    visited.append(event.id)
                    final_schedule.append(event)
                    event_delta = event.finish_date - event.start_date
                    rest_hours_for_key -= event_delta
            hours_day_delta += rest_hours_for_key
        print('After final schedule loop')

        # Fill any remaining time with events that have not been scheduled yet.
        if len(final_schedule) < len(users_events):
            visited = []
            rest_event = find_event_shorter_than_delta(
                users_events, hours_day_delta, visited)
            while rest_event is not None and len(final_schedule) <= len(users_events):
                final_schedule.append(rest_event)
                visited.append(rest_event.id)
                event_delta = rest_event.finish_date - rest_event.start_date
                hours_day_delta -= event_delta
                rest_event = find_event_shorter_than_delta(
                    users_events, hours_day_delta, visited)
            print('Added rest events')

        random.shuffle(final_schedule)
        return final_schedule
def measure_results(
        self, trial_results: typing.Iterable[TrialResult]) -> FrameErrorResult:
    """
    Collect the errors
    TODO: Track the error introduced by a loop closure, somehow.
    Might need to track loop closures in the FrameResult
    :param trial_results: The results of several trials to aggregate
    :return:
    :rtype BenchmarkResult:
    """
    trial_results = list(trial_results)

    # Preload model types for the models linked to the trial results.
    with no_auto_dereference(SLAMTrialResult):
        model_ids = set(tr.system for tr in trial_results
                        if isinstance(tr.system, bson.ObjectId))
        autoload_modules(VisionSystem, list(model_ids))
        model_ids = set(tr.image_source for tr in trial_results
                        if isinstance(tr.image_source, bson.ObjectId))
        autoload_modules(ImageSource, list(model_ids))

    # Check if the set of trial results is valid. Loads the models.
    invalid_reason = check_trial_collection(trial_results)
    if invalid_reason is not None:
        return MetricResult(
            metric=self,
            trial_results=trial_results,
            success=False,
            message=invalid_reason
        )

    # Make sure we have a non-zero number of trials to measure
    if len(trial_results) <= 0:
        return MetricResult(
            metric=self,
            trial_results=trial_results,
            success=False,
            message="Cannot measure zero trials."
        )

    # Ensure the trials all have the same number of results
    for repeat, trial_result in enumerate(trial_results[1:]):
        if len(trial_result.results) != len(trial_results[0].results):
            return MetricResult(
                metric=self,
                trial_results=trial_results,
                success=False,
                message=f"Repeat {repeat + 1} has a different number of frames "
                        f"({len(trial_result.results)} != {len(trial_results[0].results)})"
            )

    # Load the system, it must be the same for all trials (see check_trial_collection)
    system = trial_results[0].system

    # Pre-load the image objects in a batch, to avoid loading them piecemeal later
    images = [image for _, image in trial_results[0].image_source]

    # Build mappings between frame result timestamps and poses for each trial
    timestamps_to_pose = [{
        frame_result.timestamp: frame_result.pose
        for frame_result in trial_result.results
    } for trial_result in trial_results]

    # Choose transforms between each trajectory and the ground truth
    estimate_origins_and_scales = [
        robust_align_trajectory_to_ground_truth(
            [frame_result.estimated_pose for frame_result in trial_result.results
             if frame_result.estimated_pose is not None],
            [frame_result.pose for frame_result in trial_result.results
             if frame_result.estimated_pose is not None],
            compute_scale=not bool(trial_result.has_scale),
            use_symmetric_scale=True
        )
        for trial_result in trial_results
    ]
    motion_scales = [1.0] * len(trial_results)
    for idx in range(len(trial_results)):
        if not trial_results[idx].has_scale:
            motion_scales[idx] = robust_compute_motions_scale(
                [frame_result.estimated_motion
                 for frame_result in trial_results[idx].results
                 if frame_result.estimated_motion is not None],
                [frame_result.motion
                 for frame_result in trial_results[idx].results
                 if frame_result.estimated_motion is not None],
            )

    # Then, tally all the errors for all the computed trajectories
    estimate_errors = [[] for _ in range(len(trial_results))]
    image_columns = set()
    distances_lost = [[] for _ in range(len(trial_results))]
    times_lost = [[] for _ in range(len(trial_results))]
    frames_lost = [[] for _ in range(len(trial_results))]
    distances_found = [[] for _ in range(len(trial_results))]
    times_found = [[] for _ in range(len(trial_results))]
    frames_found = [[] for _ in range(len(trial_results))]

    is_tracking = [False for _ in range(len(trial_results))]
    tracking_frames = [0 for _ in range(len(trial_results))]
    tracking_distances = [0 for _ in range(len(trial_results))]
    prev_tracking_time = [0 for _ in range(len(trial_results))]
    current_tracking_time = [0 for _ in range(len(trial_results))]

    for frame_idx, frame_results in enumerate(
            zip(*(trial_result.results for trial_result in trial_results))):
        # Get the estimated motions and absolute poses for each trial,
        # and convert them to the ground truth coordinate frame using
        # the scale, translation and rotation we chose
        scaled_motions = [
            tf.Transform(
                location=frame_results[idx].estimated_motion.location * motion_scales[idx],
                rotation=frame_results[idx].estimated_motion.rotation_quat(True),
                w_first=True
            ) if frame_results[idx].estimated_motion is not None else None
            for idx in range(len(frame_results))
        ]
        scaled_poses = [
            align_point(
                pose=frame_results[idx].estimated_pose,
                shift=estimate_origins_and_scales[idx][0],
                rotation=estimate_origins_and_scales[idx][1],
                scale=estimate_origins_and_scales[idx][2]
            ) if frame_results[idx].estimated_pose is not None else None
            for idx in range(len(frame_results))
        ]

        # Find the average estimated motion for this frame across all the different trials
        # The average is not available for frames with only a single estimate
        non_null_motions = [motion for motion in scaled_motions if motion is not None]
        if len(non_null_motions) > 1:
            average_motion = tf.compute_average_pose(non_null_motions)
        else:
            average_motion = None

        # Union the image columns for all the images for all the frame results
        image_columns |= set(
            column for frame_result in frame_results
            for column in frame_result.image.get_columns())

        for repeat_idx, frame_result in enumerate(frame_results):
            # Record how long the current tracking state has persisted
            if frame_idx <= 0:
                # Cannot change to or from tracking on the first frame
                is_tracking[repeat_idx] = (frame_result.tracking_state is TrackingState.OK)
                prev_tracking_time[repeat_idx] = frame_result.timestamp
            elif is_tracking[repeat_idx] and \
                    frame_result.tracking_state is not TrackingState.OK:
                # This trial has become lost, add to the list and reset the counters
                frames_found[repeat_idx].append(tracking_frames[repeat_idx])
                distances_found[repeat_idx].append(tracking_distances[repeat_idx])
                times_found[repeat_idx].append(
                    current_tracking_time[repeat_idx] - prev_tracking_time[repeat_idx])
                tracking_frames[repeat_idx] = 0
                tracking_distances[repeat_idx] = 0
                prev_tracking_time[repeat_idx] = current_tracking_time[repeat_idx]
                is_tracking[repeat_idx] = False
            elif not is_tracking[repeat_idx] and \
                    frame_result.tracking_state is TrackingState.OK:
                # This trial has started to track, record how long it was lost for
                frames_lost[repeat_idx].append(tracking_frames[repeat_idx])
                distances_lost[repeat_idx].append(tracking_distances[repeat_idx])
                times_lost[repeat_idx].append(
                    current_tracking_time[repeat_idx] - prev_tracking_time[repeat_idx])
                tracking_frames[repeat_idx] = 0
                tracking_distances[repeat_idx] = 0
                prev_tracking_time[repeat_idx] = current_tracking_time[repeat_idx]
                is_tracking[repeat_idx] = True

            # Update the current tracking information
            tracking_frames[repeat_idx] += 1
            tracking_distances[repeat_idx] += np.linalg.norm(frame_result.motion.location)
            current_tracking_time[repeat_idx] = frame_result.timestamp

            # Turn loop closures into distances. We don't need to worry about
            # origins because everything is in the ground-truth frame.
            if len(frame_result.loop_edges) > 0:
                loop_distances, loop_angles = compute_loop_distances_and_angles(
                    frame_result.pose,
                    (
                        timestamps_to_pose[repeat_idx][timestamp]
                        for timestamp in frame_result.loop_edges
                        # They should all be in there, but for safety, check.
                        if timestamp in timestamps_to_pose[repeat_idx]
                    )
                )
            else:
                loop_distances, loop_angles = [], []

            # Build the frame error
            frame_error = make_frame_error(
                trial_result=trial_results[repeat_idx],
                frame_result=frame_result,
                image=images[frame_idx],
                system=system,
                repeat_index=repeat_idx,
                loop_distances=loop_distances,
                loop_angles=loop_angles,
                # Compute the error in the absolute estimated pose (if available)
                absolute_error=make_pose_error(
                    scaled_poses[repeat_idx],
                    frame_result.pose
                ) if scaled_poses[repeat_idx] is not None else None,
                # Compute the error of the motion relative to the true motion
                relative_error=make_pose_error(
                    scaled_motions[repeat_idx],
                    frame_result.motion
                ) if scaled_motions[repeat_idx] is not None else None,
                # Compute the error between the motion and the average estimated motion
                noise=make_pose_error(
                    scaled_motions[repeat_idx],
                    average_motion
                ) if scaled_motions[repeat_idx] is not None and average_motion is not None
                else None,
                systemic_error=make_pose_error(
                    average_motion,
                    frame_result.motion
                ) if average_motion is not None else None
            )
            estimate_errors[repeat_idx].append(frame_error)

    # Add any accumulated tracking information left over at the end
    if len(trial_results[0].results) > 0:
        for repeat_idx, tracking in enumerate(is_tracking):
            if tracking:
                frames_found[repeat_idx].append(tracking_frames[repeat_idx])
                distances_found[repeat_idx].append(tracking_distances[repeat_idx])
                times_found[repeat_idx].append(
                    current_tracking_time[repeat_idx] - prev_tracking_time[repeat_idx])
            else:
                frames_lost[repeat_idx].append(tracking_frames[repeat_idx])
                distances_lost[repeat_idx].append(tracking_distances[repeat_idx])
                times_lost[repeat_idx].append(
                    current_tracking_time[repeat_idx] - prev_tracking_time[repeat_idx])

    # Once we've tallied all the results, either succeed or fail based on the number of results.
    if len(estimate_errors) <= 0 or any(
            len(trial_errors) <= 0 for trial_errors in estimate_errors):
        return FrameErrorResult(
            metric=self,
            trial_results=trial_results,
            success=False,
            message="No measurable errors for these trajectories"
        )
    return make_frame_error_result(
        metric=self,
        trial_results=trial_results,
        errors=[
            TrialErrors(
                frame_errors=estimate_errors[repeat],
                frames_lost=frames_lost[repeat],
                frames_found=frames_found[repeat],
                times_lost=times_lost[repeat],
                times_found=times_found[repeat],
                distances_lost=distances_lost[repeat],
                distances_found=distances_found[repeat]
            )
            for repeat, trial_result in enumerate(trial_results)
        ]
    )
def make_frame_error(
        trial_result: TrialResult,
        frame_result: FrameResult,
        image: typing.Union[None, Image],
        system: typing.Union[None, VisionSystem],
        repeat_index: int,
        absolute_error: typing.Union[None, PoseError],
        relative_error: typing.Union[None, PoseError],
        noise: typing.Union[None, PoseError],
        systemic_error: typing.Union[None, PoseError],
        loop_distances: typing.Iterable[float],
        loop_angles: typing.Iterable[float]
) -> FrameError:
    """
    Construct a FrameError object from a context.
    The frame error copies data from its linked objects (like the system or image)
    to avoid having to dereference them later. This function makes sure that data is consistent.

    It takes the system and image, even though the trial result and frame result should refer
    to those, because you're usually creating lots of FrameError objects at the same time,
    so you should load those objects _once_, and pass them in each time.
    Just because you've loaded the object doesn't mean the FrameResult has that object.

    :param trial_result: The trial result producing this FrameError
    :param frame_result: The FrameResult for the specific frame this error corresponds to
    :param image: The specific image this error corresponds to. Will be pulled from frame_result if None.
    :param system: The system that produced the trial_result. Will be pulled from the trial result if None.
    :param repeat_index: The repeat index of the trial result, for identification within the set
    :param absolute_error: The error in the estimated pose, in an absolute reference frame.
    :param relative_error: The error in the estimated motion, relative to the previous frame.
    :param noise: The error between this particular motion estimate, and the average motion estimate from all trials.
    :param systemic_error: The difference between the mean estimate and the true motion. Should be near zero.
    :param loop_distances: The distances to other images to which this image has a loop closure. Will usually be empty.
    :param loop_angles: The angles to other images to which this image has a loop closure.
    :return: A FrameError object, containing the errors and related metadata.
    """
    # Make sure the image we're given is the same as the one from the frame_result, without reloading it
    if image is None:
        image = frame_result.image
    else:
        with no_auto_dereference(type(frame_result)):
            if isinstance(frame_result.image, bson.ObjectId):
                image_id = frame_result.image
            else:
                image_id = frame_result.image.pk
        if image_id != image.pk:
            image = frame_result.image

    # Make sure the given system matches the trial result, avoiding loading it unnecessarily
    if system is None:
        system = trial_result.system
    else:
        with no_auto_dereference(type(trial_result)):
            if isinstance(trial_result.system, bson.ObjectId):
                system_id = trial_result.system
            else:
                system_id = trial_result.system.pk
        if system_id != system.pk:
            system = trial_result.system

    # Check that the loop distances and angles are the same length
    loop_distances = list(loop_distances)
    loop_angles = list(loop_angles)
    if len(loop_distances) != len(loop_angles):
        raise ValueError("Loop distances and loop angles must always be the same length, "
                         f"was {loop_distances} and {loop_angles}")

    # Read the system properties from the trial result
    system_properties = system.get_properties(None, trial_result.settings)
    image_properties = image.get_properties()

    return FrameError(
        trial_result=trial_result,
        image=image,
        repeat=repeat_index,
        timestamp=frame_result.timestamp,
        motion=frame_result.motion,
        processing_time=frame_result.processing_time,
        loop_distances=loop_distances,
        loop_angles=loop_angles,
        num_features=frame_result.num_features,
        num_matches=frame_result.num_matches,
        tracking=frame_result.tracking_state,
        absolute_error=absolute_error,
        relative_error=relative_error,
        noise=noise,
        systemic_error=systemic_error,
        system_properties={str(k): json_value(v) for k, v in system_properties.items()},
        image_properties={str(k): json_value(v) for k, v in image_properties.items()}
    )