예제 #1
0
    def test_list_embedded_reference_dereference(self):
        """Dereference items stored in ListField(EmbeddedModel(Reference)).

        After a reload with auto-dereferencing disabled every embedded ``ref``
        must be a raw ObjectId; after ``dereference`` every one of them must
        be a loaded ``OtherModel``.
        """
        class OtherModel(MongoModel):
            name = fields.CharField()

        class OtherRefModel(EmbeddedMongoModel):
            ref = fields.ReferenceField(OtherModel)

        class Container(MongoModel):
            lst = fields.EmbeddedModelListField(OtherRefModel)

        m1 = OtherModel('Aaron').save()
        m2 = OtherModel('Bob').save()

        container = Container(lst=[OtherRefModel(ref=m1),
                                   OtherRefModel(ref=m2)])
        container.save()

        with no_auto_dereference(container):
            # Force ObjectIds: reloading replaces the in-memory models with
            # the raw ids stored in the database.
            container.refresh_from_db()
            for item in container.lst:
                self.assertIsInstance(item.ref, ObjectId)

            # Explicit dereferencing must resolve every list element even
            # while automatic dereferencing is disabled.
            dereference(container)
            self.assertIsInstance(container.lst[0].ref, OtherModel)
            self.assertEqual(container.lst[0].ref.name, 'Aaron')
            self.assertIsInstance(container.lst[1].ref, OtherModel)
            self.assertEqual(container.lst[1].ref.name, 'Bob')

        # The dereferenced values remain available during normal access.
        self.assertEqual(container.lst[0].ref.name, 'Aaron')
        self.assertEqual(container.lst[1].ref.name, 'Bob')
예제 #2
0
def dereference(model_instance, fields=None):
    """Resolve the ReferenceFields of a MongoModel instance in bulk.

    All referenced ids are gathered first and then fetched one collection at
    a time, which is cheaper than dereferencing each field on its own.

    :parameters:
      - `model_instance`: The MongoModel instance.
      - `fields`: An iterable of field names in "dot" notation that
        should be dereferenced. If left blank, all fields will be dereferenced.

    :returns: `model_instance` itself, with resolved references attached.
    """
    # Turn each dotted path into a queue of its components ('a.b' -> a, b).
    if fields:
        fields = [deque(dotted.split('.')) for dotted in fields]

    # collection name --> ids that must be fetched from that collection.
    ids_by_collection = defaultdict(list)

    # ReferenceFields must not look up their values while we walk the object.
    with no_auto_dereference(model_instance):
        _find_references(model_instance, ids_by_collection, fields)

        alias = model_instance._mongometa.connection_alias
        database = _get_db(alias)
        # {collection name --> {id --> resolved object}}
        resolved = _resolve_references(database, ids_by_collection)

        # Walk the object again, swapping raw ids for resolved objects.
        _attach_objects(model_instance, resolved, fields)

    return model_instance
예제 #3
0
    def test_select_related(self):
        """select_related() resolves ListField references in the query."""
        class Comment(MongoModel):
            body = fields.CharField()

        class Post(MongoModel):
            body = fields.CharField()
            comments = fields.ListField(fields.ReferenceField(Comment))

        # One post without comments and one post with two comments.
        Post(body='Nobody read this post').save()
        saved_comments = [
            Comment(body=text).save()
            for text in ('This is a great post', 'Horrible read')
        ]
        Post(body='More popular post', comments=saved_comments).save()

        with no_auto_dereference(Post):
            # A plain query leaves the raw ObjectIds in the list.
            lazy_posts = list(Post.objects.all())
            self.assertIsNone(lazy_posts[0].comments)
            for idx in (0, 1):
                self.assertIsInstance(lazy_posts[1].comments[idx], ObjectId)

            # select_related() hands back fully resolved comments.
            eager_posts = list(Post.objects.select_related())
            self.assertIsNone(eager_posts[0].comments)
            self.assertEqual(eager_posts[1].comments, saved_comments)
예제 #4
0
    def test_list_dereference(self):
        """Dereference items stored in a ListField(ReferenceField(X))."""
        class OtherModel(MongoModel):
            name = fields.CharField()

        class Container(MongoModel):
            one_to_many = fields.ListField(fields.ReferenceField(OtherModel))

        first = OtherModel('a').save()
        second = OtherModel('b').save()
        container = Container([first, second]).save()

        # Attribute access dereferences automatically after a reload.
        container.refresh_from_db()
        self.assertEqual([first, second], container.one_to_many)

        # Reload again and confirm the stored values are raw ObjectIds.
        container.refresh_from_db()
        with no_auto_dereference(container):
            for raw_ref in container.one_to_many:
                self.assertIsInstance(raw_ref, ObjectId)

        # dereference() resolves the whole list explicitly.
        dereference(container)
        self.assertEqual([first, second], container.one_to_many)
예제 #5
0
 def test_dereference_reference_not_found(self):
     """A reference whose target was deleted dereferences to None."""
     post = Post(title='title').save()
     comment = Comment(body='this is a comment', post=post).save()

     # Delete the referenced post so the comment's reference dangles.
     post.delete()
     self.assertEqual(Post.objects.count(), 0)

     comment.refresh_from_db()
     with no_auto_dereference(comment):
         # The raw stored reference equals the deleted post's title
         # (presumably Post's primary key -- confirm against the model).
         self.assertEqual(comment.post, 'title')
         dereference(comment)
         # Nothing could be resolved, so the field ends up as None.
         self.assertIsNone(comment.post)
예제 #6
0
    def test_dereference_dereferenced_reference(self):
        """dereference() gives the same result when applied repeatedly."""
        class CommentContainer(MongoModel):
            ref = fields.ReferenceField(Comment)

        post = Post(title='title').save()
        comment = Comment(body='Comment Body', post=post).save()
        container = CommentContainer(ref=comment).save()

        with no_auto_dereference(comment), no_auto_dereference(container):
            comment.refresh_from_db()
            container.refresh_from_db()
            # Attach the reloaded comment, whose own reference is still raw.
            container.ref = comment
            self.assertEqual(container.ref.post, 'title')

            # Two consecutive passes must both yield a resolved Post.
            for _ in range(2):
                dereference(container)
                self.assertIsInstance(container.ref.post, Post)
                self.assertEqual(container.ref.post.title, 'title')
예제 #7
0
    def test_no_auto_dereference(self):
        """no_auto_dereference(User) exposes raw ids only inside its block."""
        game = Game('Civilization').save()
        badge = Badge(name='World Domination', game=game)
        ernie = User(fname='Ernie').save()
        bert = User(fname='Bert', badges=[badge], friend=ernie).save()

        # Reload so the references are held as stored ids again.
        bert.refresh_from_db()

        with no_auto_dereference(User):
            # Both the top-level and the embedded reference stay raw.
            self.assertIsInstance(bert.friend, ObjectId)
            self.assertIsInstance(bert.badges[0].game, ObjectId)

        # Outside the block, access resolves them to model instances.
        self.assertIsInstance(bert.friend, User)
        self.assertIsInstance(bert.badges[0].game, Game)
예제 #8
0
    def test_leaf_field_dereference(self):
        """A ReferenceField declared directly on the model dereferences."""
        post = Post(title='This is a post.').save()
        comment = Comment(
            body='This is a comment on the post.', post=post).save()

        # Reload so comment.post holds the raw stored value.
        comment.refresh_from_db()
        with no_auto_dereference(Comment):
            # The raw value equals the post's title (presumably Post's pk).
            self.assertEqual(comment.post, post.title)

            # Explicit dereferencing works with auto-lookup disabled.
            dereference(comment)
            self.assertEqual(comment.post, post)
예제 #9
0
    def full_clean(self, exclude=None):
        """Run all validation for this :class:`~pymodm.MongoModel`.

        Field-level validation (:meth:`~pymodm.MongoModel.clean_fields`) runs
        first, followed by model-wide validation
        (:meth:`~pymodm.MongoModel.clean`).

        :parameters:
          - `exclude`: A list of fields to exclude from validation.

        """
        # Field validation sees the stored values; ReferenceFields are told
        # not to look up their targets while it runs.
        with no_auto_dereference(self):
            self.clean_fields(exclude=exclude)
        # Model-wide rules run after auto-dereferencing is restored.
        self.clean()
예제 #10
0
    def test_circular_reference(self):
        """Two models that reference each other round-trip correctly."""
        class ReferenceA(MongoModel):
            ref = fields.ReferenceField('ReferenceB')

        class ReferenceB(MongoModel):
            ref = fields.ReferenceField(ReferenceA)

        a = ReferenceA().save()
        b = ReferenceB().save()
        # Close the cycle now that both documents exist, then persist it.
        a.ref, b.ref = b, a
        for model in (a, b):
            model.save()

        self.assertEqual(a, ReferenceA.objects.first())
        with no_auto_dereference(ReferenceA):
            # Without dereferencing only B's primary key is returned.
            self.assertEqual(b.pk, ReferenceA.objects.first().ref)
        eager_a = ReferenceA.objects.select_related().first()
        self.assertEqual(b, eager_a.ref)
예제 #11
0
    def to_son(self):
        """Serialize this Model into a :class:`~bson.son.SON` document.

        :returns: SON representing this object as a MongoDB document.

        """
        son = SON()
        # Read the stored field values without resolving ReferenceFields.
        with no_auto_dereference(self):
            for field in self._mongometa.get_fields():
                if field.is_undefined(self):
                    # Undefined fields are not written at all.
                    continue
                raw_value = self._data.get(field.attname)
                # Blank values are copied verbatim; others go via to_mongo().
                son[field.mongo_name] = (
                    raw_value if field.is_blank(raw_value)
                    else field.to_mongo(raw_value))
        # For non-final models, record the class name so the matching class
        # is instantiated when the document is read back.
        if not self._mongometa.final:
            son['_cls'] = self._mongometa.object_name
        return son
예제 #12
0
    def test_dereference_fields(self):
        """Only the requested dotted field paths are dereferenced."""
        # Contrived models with ReferenceFields at two levels of nesting.
        class MultiReferenceModelEmbed(MongoModel):
            comments = fields.ListField(fields.ReferenceField(Comment))
            posts = fields.ListField(fields.ReferenceField(Post))

        class MultiReferenceModel(MongoModel):
            comments = fields.ListField(fields.ReferenceField(Comment))
            posts = fields.ListField(fields.ReferenceField(Post))
            embeds = fields.EmbeddedDocumentListField(MultiReferenceModelEmbed)

        post = Post(title='This is a post.').save()
        comments = [
            Comment('comment 1', post).save(),
            Comment('comment 2').save()
        ]
        embed = MultiReferenceModelEmbed(comments=comments, posts=[post])
        multi_ref = MultiReferenceModel(
            comments=comments, posts=[post], embeds=[embed]).save()

        # Reload so every reference is a raw ObjectId, then resolve a subset.
        multi_ref.refresh_from_db()
        dereference(multi_ref, fields=['embeds.comments', 'posts'])

        # Refresh the expected objects from the database before comparing.
        for expected in [post] + comments:
            expected.refresh_from_db()

        with no_auto_dereference(MultiReferenceModel):
            self.assertEqual([post], multi_ref.posts)
            self.assertEqual(comments, multi_ref.embeds[0].comments)
            # Top-level 'comments' was not requested, so it stays raw.
            self.assertIsInstance(multi_ref.comments[0], ObjectId)
예제 #13
0
    def _test_unhashable_id(self, final_value=True):
        """Reference models whose primary key is unhashable (an embedded doc).

        :parameters:
          - `final_value`: Value used for ``CardIdentity.Meta.final``.
        """
        class CardIdentity(EmbeddedMongoModel):
            HEARTS, DIAMONDS, SPADES, CLUBS = 0, 1, 2, 3

            rank = fields.IntegerField(min_value=0, max_value=12)
            suit = fields.IntegerField(
                choices=(HEARTS, DIAMONDS, SPADES, CLUBS))

            class Meta:
                final = final_value

        class Card(MongoModel):
            id = fields.EmbeddedModelField(CardIdentity, primary_key=True)
            flavor = fields.CharField()

        class Hand(MongoModel):
            cards = fields.ListField(fields.ReferenceField(Card))

        cards = [
            Card(CardIdentity(rank, suit)).save()
            for rank, suit in ((4, CardIdentity.CLUBS),
                               (12, CardIdentity.SPADES))
        ]
        hand = Hand(cards).save()

        def check_cards():
            # Each slot must hold a loaded Card with the expected rank.
            for index, rank in enumerate((4, 12)):
                self.assertIsInstance(hand.cards[index], Card)
                self.assertEqual(hand.cards[index].id.rank, rank)

        # Automatic dereferencing on attribute access.
        hand.refresh_from_db()
        check_cards()

        # Explicit dereferencing with auto-lookup disabled.
        with no_auto_dereference(hand):
            hand.refresh_from_db()
            dereference(hand)
            check_cards()
예제 #14
0
    def test_dereference_models_with_same_id(self):
        """References to different models sharing an id value both resolve."""
        class User(MongoModel):
            name = fields.CharField(primary_key=True)

        class CommentWithUser(MongoModel):
            body = fields.CharField()
            post = fields.ReferenceField(Post)
            user = fields.ReferenceField(User)

        # Both referenced documents are built from the same string.
        shared = 'Bob'
        post = Post(title=shared).save()
        user = User(name=shared).save()

        comment = CommentWithUser(
            body='this is a comment', post=post, user=user).save()

        comment.refresh_from_db()
        with no_auto_dereference(CommentWithUser):
            dereference(comment)
            # Each reference must resolve to its own model type.
            self.assertIsInstance(comment.post, Post)
            self.assertIsInstance(comment.user, User)
예제 #15
0
    def test_refresh_from_db(self):
        """refresh_from_db() reloads state, optionally for selected fields."""
        post = Post(body='This is a post.')
        comment = Comment(body='This is a comment on the post.', post=post)
        post.save()
        comment.save()

        comment.refresh_from_db()
        with no_auto_dereference(Comment):
            self.assertIsInstance(comment.post, ObjectId)

        # Edit the stored document directly with PyMongo.
        DB.comment.update_one(
            {'_id': comment.pk},
            {'$set': {'body': 'Edited comment.'}},
        )
        # Point the in-memory "post" elsewhere, then reload only "body".
        comment.post = Post(body='This is a different post.')
        comment.refresh_from_db(fields=['body'])
        self.assertEqual('Edited comment.', comment.body)
        # "post" was not in the projection, so it is now None.
        self.assertIsNone(comment.post)
예제 #16
0
    def test_schedule_tasks_uses_measure_trial_tasks_correctly(
            self, mock_task_manager, mock_load_minimal_trial):
        """Every result id from a measure-trial task lands in metric_results."""
        systems = [mock_core.MockSystem(_id=ObjectId())]
        image_sources = [mock_core.MockImageSource(_id=ObjectId())]
        metrics = [mock_core.MockMetric(_id=ObjectId())]
        repeats = 1

        trial_results_map = make_trial_map(systems, image_sources, repeats)
        make_mock_get_run_system_task(mock_task_manager, trial_results_map)
        patch_load_minimal_trial(mock_load_minimal_trial)

        # Result ids handed out by the fake tasks, in creation order.
        issued_ids = []

        def fake_measure_trial_task(*_, **__):
            # A finished task carrying a freshly generated result id.
            new_id = ObjectId()
            issued_ids.append(new_id)
            task = mock.create_autospec(MeasureTrialTask)
            task.is_finished = True
            task.result_id = new_id
            return task

        mock_task_manager.get_measure_trial_task.side_effect = (
            fake_measure_trial_task)

        subject = SimpleExperiment(
            name="TestSimpleExperiment",
            systems=systems,
            image_sources=image_sources,
            metrics=metrics,
            repeats=repeats,
        )
        subject.schedule_tasks()

        with no_auto_dereference(SimpleExperiment):
            for issued_id in issued_ids:
                self.assertIn(issued_id, subject.metric_results)
예제 #17
0
    def test_dereference_fields(self):
        """dereference(fields=...) limits resolution to the given paths."""
        # Contrived models with ReferenceFields at two levels of nesting.
        class MultiReferenceModelEmbed(MongoModel):
            comments = fields.ListField(fields.ReferenceField(Comment))
            posts = fields.ListField(fields.ReferenceField(Post))

        class MultiReferenceModel(MongoModel):
            comments = fields.ListField(fields.ReferenceField(Comment))
            posts = fields.ListField(fields.ReferenceField(Post))
            embeds = fields.EmbeddedDocumentListField(MultiReferenceModelEmbed)

        post = Post(title='This is a post.').save()
        comments = [
            Comment('comment 1', post).save(),
            Comment('comment 2').save()
        ]
        embed = MultiReferenceModelEmbed(comments=comments, posts=[post])
        multi_ref = MultiReferenceModel(
            comments=comments, posts=[post], embeds=[embed]).save()

        # Reload so every reference is raw, then resolve just two paths.
        multi_ref.refresh_from_db()
        requested = ['embeds.comments', 'posts']
        dereference(multi_ref, fields=requested)

        # Refresh the expected objects from the database before comparing.
        for expected in [post] + comments:
            expected.refresh_from_db()

        with no_auto_dereference(MultiReferenceModel):
            self.assertEqual([post], multi_ref.posts)
            self.assertEqual(comments, multi_ref.embeds[0].comments)
            # The unrequested top-level 'comments' stays as ObjectIds.
            self.assertIsInstance(multi_ref.comments[0], ObjectId)
예제 #18
0
    def test_refresh_from_db(self):
        """A partial refresh_from_db() clears fields outside the projection."""
        post = Post(body='This is a post.')
        comment = Comment(body='This is a comment on the post.', post=post)
        post.save()
        comment.save()

        comment.refresh_from_db()
        with no_auto_dereference(Comment):
            self.assertIsInstance(comment.post, ObjectId)

        # Modify the stored document behind the model's back.
        edit = {'$set': {'body': 'Edited comment.'}}
        DB.comment.update_one({'_id': comment.pk}, edit)

        # Point the in-memory "post" elsewhere, then reload only "body".
        comment.post = Post(body='This is a different post.')
        comment.refresh_from_db(fields=['body'])
        self.assertEqual('Edited comment.', comment.body)
        # "post" was not projected, so it is now None.
        self.assertIsNone(comment.post)
예제 #19
0
    def test_embedded_reference_dereference(self):
        """Dereference a reference in EmbeddedDocument(ReferenceField(X))."""
        class OtherModel(MongoModel):
            name = fields.CharField()

        class OtherRefModel(EmbeddedMongoModel):
            ref = fields.ReferenceField(OtherModel)

        class Container(MongoModel):
            emb = fields.EmbeddedDocumentField(OtherRefModel)

        aaron = OtherModel('Aaron').save()
        container = Container(emb=OtherRefModel(ref=aaron))
        container.save()

        with no_auto_dereference(container):
            # Reloading leaves only the raw ObjectId in the embedded doc.
            container.refresh_from_db()
            self.assertIsInstance(container.emb.ref, ObjectId)

            # Explicit dereferencing swaps it for the loaded model.
            dereference(container)
            resolved = container.emb.ref
            self.assertIsInstance(resolved, OtherModel)
            self.assertEqual(resolved.name, 'Aaron')
예제 #20
0
    def test_embedded_reference_dereference(self):
        """Dereference a reference in EmbeddedModel(ReferenceField(X))."""
        class OtherModel(MongoModel):
            name = fields.CharField()

        class OtherRefModel(EmbeddedMongoModel):
            ref = fields.ReferenceField(OtherModel)

        class Container(MongoModel):
            emb = fields.EmbeddedModelField(OtherRefModel)

        aaron = OtherModel('Aaron').save()
        container = Container(emb=OtherRefModel(ref=aaron))
        container.save()

        with no_auto_dereference(container):
            # After the reload the embedded model holds a raw ObjectId.
            container.refresh_from_db()
            self.assertIsInstance(container.emb.ref, ObjectId)

            # dereference() replaces it with the referenced model.
            dereference(container)
            resolved = container.emb.ref
            self.assertIsInstance(resolved, OtherModel)
            self.assertEqual(resolved.name, 'Aaron')
예제 #21
0
 def add_image_sources(self, image_sources: typing.Iterable[ImageSource]):
     """
     Add the given image sources to this experiment if they are not already associated with it.
     Sources are matched by primary key; a source repeated within ``image_sources`` is only
     added once.
     :param image_sources: The image sources to associate with this experiment.
     :return: None
     """
     with no_auto_dereference(type(self)):
         # With auto-dereferencing disabled the stored list may hold raw ids as
         # well as loaded models, so accept either when collecting known pks.
         known_pks = set()
         if self.image_sources is not None:
             known_pks = {
                 image_source.pk
                 if hasattr(image_source, 'pk') else image_source
                 for image_source in self.image_sources
             }
         new_image_sources = []
         for image_source in image_sources:
             if image_source.pk in known_pks:
                 continue
             # Remember this pk so a repeat later in the input is skipped.
             known_pks.add(image_source.pk)
             new_image_sources.append(image_source)
         if new_image_sources:
             if self.image_sources is None:
                 self.image_sources = new_image_sources
             else:
                 self.image_sources.extend(new_image_sources)
예제 #22
0
 def add_vision_systems(self,
                        vision_systems: typing.Iterable[VisionSystem]):
     """
     Add the given vision systems to this experiment if they are not already associated with it.
     Systems are matched by primary key; a system repeated within ``vision_systems`` is only
     added once.
     :param vision_systems: The vision systems to associate with this experiment.
     :return: None
     """
     with no_auto_dereference(type(self)):
         # With auto-dereferencing disabled self.systems may hold raw ids as
         # well as loaded models, so accept either when collecting known pks.
         known_pks = set()
         if self.systems is not None:
             known_pks = {
                 system.pk if hasattr(system, 'pk') else system
                 for system in self.systems
             }
         new_vision_systems = []
         for vision_system in vision_systems:
             if vision_system.pk in known_pks:
                 continue
             # Remember this pk so a repeat later in the input is skipped.
             known_pks.add(vision_system.pk)
             new_vision_systems.append(vision_system)
         if new_vision_systems:
             if self.systems is None:
                 self.systems = new_vision_systems
             else:
                 self.systems.extend(new_vision_systems)
예제 #23
0
 def load_referenced_models(self) -> None:
     """
     Load the metric, trial, and result types so we can save the task.
     Every reference that is still a raw bson.ObjectId is passed to autoload_modules
     together with its base model class.
     :return: None
     """
     def unresolved(values):
         # Keep only the entries that are still plain ObjectIds.
         return (value for value in values if isinstance(value, bson.ObjectId))

     with no_auto_dereference(CompareTrialTask):
         if isinstance(self.metric, bson.ObjectId):
             # The metric is only an id, not a model; autoload its model.
             autoload_modules(TrialComparisonMetric, [self.metric])

         # Raw ids from both trial-result lists, de-duplicated.
         trials_to_load = set(unresolved(self.trial_results_1))
         trials_to_load |= set(unresolved(self.trial_results_2))
         if trials_to_load:
             autoload_modules(TrialResult, list(trials_to_load))

         if isinstance(self.result, bson.ObjectId):
             # The result is only an id, not a model; autoload its model.
             autoload_modules(TrialComparisonResult, [self.result])
예제 #24
0
파일: fsapi.py 프로젝트: r0b0tAnthony/fsapi
def CreateFile(**request_handler_args):
    """Create a file or folder inside a project's file system.

    On success the response status is HTTP 201 and ``req.context['result']``
    describes the new path (ACL.GetACL output plus ctime/mtime/atime).

    :raises falcon.HTTPBadRequest: invalid project id, a ``path``/``platform``
        that cannot be translated, or a missing/invalid ``type``.
    :raises falcon.HTTPNotFound: no project with the given id.
    :raises falcon.HTTPForbidden: the user is not assigned to the project.
    :raises falcon.HTTPInternalServerError: the file-system or ACL call failed.
    """
    req = request_handler_args['req']
    resp = request_handler_args['resp']
    authUser(req, resp, ['createFile'])
    user = req.context['user']
    doc = req.context['doc']
    try:
        project = Project.objects.get(
            {"_id": ObjectId(request_handler_args['uri_fields']['id'])})
    except InvalidId as e:
        raise falcon.HTTPBadRequest('Bad Request', str(e))
    except Project.DoesNotExist:
        raise falcon.HTTPNotFound()

    # Read the project's references raw, so membership is tested against
    # user._id directly.
    with context_managers.no_auto_dereference(Project):
        if user._id not in project.users:
            raise falcon.HTTPForbidden(
                'Forbidden',
                "%s is not assigned to this project." % user.username)

        try:
            path = ProjectFS.TranslatePath(doc['path'], doc['platform'],
                                           project.paths)
        except (ValueError, KeyError) as e:
            # str(e) works on Python 2 and 3; ``e.message`` is Python 2 only.
            raise falcon.HTTPBadRequest('Bad Request', str(e))

        try:
            if doc['type'] == 'file':
                ProjectFS.CreateFile(path)
            elif doc['type'] == 'folder':
                ProjectFS.CreateDirectory(path)
            else:
                raise falcon.HTTPBadRequest(
                    'Bad Request',
                    'Type property of FS object must be either file or folder.'
                )
        except (IOError, OSError) as e:
            req.context['logger'].error({
                'action': 'createFile',
                'message': str(e)
            })
            raise falcon.HTTPInternalServerError(
                'Internal Server Error', str(e))
        except KeyError:
            raise falcon.HTTPBadRequest(
                'Bad Request', 'FS Object is missing type property.')

        try:
            resp.status = falcon.HTTP_201
            req.context['result'] = {
                'path': path,
                'security': ACL.GetACL(path),
                'created': ProjectFS.GetCTime(path),
                'modified': ProjectFS.GetMTime(path),
                'accessed': ProjectFS.GetATime(path)
            }
            req.context['logger'].info({
                'action':
                'createFile',
                'message':
                "File/Folder '%s' was created for project %s(%s)" %
                (path, project.name, project._id)
            })
        except ACL.error as e:
            req.context['logger'].error({
                'action': 'createFile',
                'message': str(e)
            })
            raise falcon.HTTPInternalServerError(
                'Internal Server Error', str(e))
예제 #25
0
 def test_dereference_missed_reference_field(self):
     """Dereferencing a ReferenceField that was never set leaves None."""
     comment = Comment(body='Body Comment').save()
     with no_auto_dereference(comment):
         # Reload, then dereference: there is no post to resolve.
         comment.refresh_from_db()
         dereference(comment)
         self.assertIsNone(comment.post)
예제 #26
0
파일: fsapi.py 프로젝트: r0b0tAnthony/fsapi
def SetACL(**request_handler_args):
    """Apply the project's matching ACL rule to a file-system path.

    ``ProjectFS.GetMatchACLPath`` selects a rule for ``doc['path']`` on
    ``doc['platform']``. When a rule matches, it is applied with
    ``ACL.SetMatchedACL``. When nothing matches, the path is left untouched
    and the response status is HTTP 202. In both cases
    ``req.context['result']`` receives ACL.GetACL output plus the
    ctime/mtime/atime of the path.

    :raises falcon.HTTPBadRequest: invalid project id, or a missing key /
        invalid value while resolving the path.
    :raises falcon.HTTPNotFound: no project with the given id.
    :raises falcon.HTTPForbidden: the user is not assigned to the project.
    :raises falcon.HTTPInternalServerError: an ACL call failed.
    """
    req = request_handler_args['req']
    resp = request_handler_args['resp']
    authUser(req, resp, ['setACL'])
    user = req.context['user']
    doc = req.context['doc']
    try:
        project = Project.objects.get(
            {"_id": ObjectId(request_handler_args['uri_fields']['id'])})
    except InvalidId as e:
        raise falcon.HTTPBadRequest('Bad Request', str(e))
    except Project.DoesNotExist:
        raise falcon.HTTPNotFound()

    # Read the project's references raw, so membership is tested against
    # user._id directly.
    with context_managers.no_auto_dereference(Project):
        if user._id not in project.users:
            raise falcon.HTTPForbidden(
                'Forbidden',
                "%s is not assigned to this project." % user.username)

        try:
            matched_acl = ProjectFS.GetMatchACLPath(
                doc['path'], doc['platform'], project.acl_expanded,
                project.acl_expanded_depth)
            path = ProjectFS.TranslatePath(doc['path'], doc['platform'],
                                           project.paths)
        except (KeyError, ValueError) as e:
            # str(e) works on Python 2 and 3; ``e.message`` is Python 2 only.
            raise falcon.HTTPBadRequest('Bad Request: Missing Key', str(e))

        if matched_acl is not None:
            try:
                ACL.SetMatchedACL(path, matched_acl)
            except ACL.error as e:
                # ``e.args[2]`` is the portable form of the Python 2-only
                # ``e[2]`` exception indexing.
                req.context['logger'].error({
                    'action': 'setACL',
                    'message': e.args[2]
                })
                raise falcon.HTTPInternalServerError(
                    'Internal Server Error', e.args[2])
        else:
            # No rule matched, so nothing was applied: answer 202.
            resp.status = falcon.HTTP_202

        try:
            req.context['result'] = {
                'path': path,
                'security': ACL.GetACL(path),
                'created': ProjectFS.GetCTime(path),
                'modified': ProjectFS.GetMTime(path),
                'accessed': ProjectFS.GetATime(path)
            }
            if matched_acl is None:
                req.context['logger'].info({
                    'action':
                    'setACL',
                    'message':
                    "No ACL matched '%s' for project %s(%s); ACL unchanged."
                    % (path, project.name, project._id)
                })
        except ACL.error as e:
            req.context['logger'].error({
                'action': 'setACL',
                'message': str(e)
            })
            raise falcon.HTTPInternalServerError(
                'Internal Server Error', str(e))
예제 #27
0
def get_brute_schedule(user):
    """Build a one-day schedule for ``user`` from their stored events.

    Events are grouped by category name (``'default'`` when an event has no
    category). Categories are ranked by how many events they hold, and each
    category gets an hour budget derived from ``get_probs``. Events are then
    picked from each category with ``find_event_shorter_than_delta`` until
    less than ten minutes of the budget remain or the helper returns None.
    Time left over in the day is finally filled with not-yet-scheduled events
    from any category.

    :param user: The user whose events are scheduled; only ``user.id`` is read.
    :return: A shuffled list of the scheduled ``Event`` objects (an empty list
        when the user has no events).
    """
    with no_auto_dereference(Event):
        users_events = list(Event.objects.raw({'user': user.id}).all())
        events = {}
        for event in users_events:
            key = event.category['name'] if event.category else 'default'
            event.duration = event.finish_date - event.start_date
            events.setdefault(key, []).append(event)

    print('Made events')

    # Nothing to schedule: avoid feeding empty arrays into the statistics below.
    if not events:
        return []

    # Category names ordered from most to least populated (stable on ties).
    ranked_keys = sorted(events, key=lambda k: len(events[k]), reverse=True)

    print('Made sorted keys')

    sorted_keys = [{'key': key, 'size': len(events[key])} for key in ranked_keys]

    print('Updated sorted keys')

    sizes = np.array([item['size'] for item in sorted_keys])
    mu = sizes.mean()
    sigma = sizes.var() / len(sizes)
    pdf = np.random.normal(mu, sizes.var()**0.5, len(sizes))
    probs = get_probs(pdf, mu, sigma)

    print('Made probs')

    hours_in_day = 15
    # Every category receives at least one hour.
    hours = [int(hours_in_day * prob) or 1 for prob in probs]
    overflow = sum(hours) - hours_in_day
    if overflow > 0:
        # Take the excess from the largest allocation (the first one on ties).
        hours[hours.index(max(hours))] -= overflow
    hours_day_delta = timedelta(hours=hours_in_day)
    final_schedule = []

    print('Before final schedule loop')
    for hour, item in zip(hours, sorted_keys):
        key_events = events[item['key']]
        rest_hours_for_key = timedelta(hours=hour)
        # Reserve this category's budget; the unused part is returned below.
        hours_day_delta -= rest_hours_for_key
        visited = []
        while rest_hours_for_key.total_seconds() / 60 > 10 and len(
                visited) <= len(key_events):
            event = find_event_shorter_than_delta(key_events,
                                                  rest_hours_for_key, visited)
            if event is None:
                break
            visited.append(event.id)
            final_schedule.append(event)
            rest_hours_for_key -= event.finish_date - event.start_date
        # Give the unused budget back to the day exactly once (previously it
        # was added twice whenever no further event was found).
        hours_day_delta += rest_hours_for_key

    print('After final schedule loop')
    if len(final_schedule) < len(users_events):
        # Seed with already-scheduled ids so no event is scheduled twice.
        visited = [event.id for event in final_schedule]
        rest_event = find_event_shorter_than_delta(users_events,
                                                   hours_day_delta, visited)
        while rest_event is not None and len(final_schedule) < len(
                users_events):
            final_schedule.append(rest_event)
            visited.append(rest_event.id)
            hours_day_delta -= rest_event.finish_date - rest_event.start_date
            rest_event = find_event_shorter_than_delta(users_events,
                                                       hours_day_delta,
                                                       visited)

    print('Added rest events')
    random.shuffle(final_schedule)
    return final_schedule
예제 #28
0
 def test_dereference_missed_reference_field(self):
     """A Comment whose `post` reference was never set keeps `post` as None after dereference."""
     orphan = Comment(body='Body Comment').save()
     with no_auto_dereference(orphan):
         # Reload from the database, then resolve references in one pass.
         orphan.refresh_from_db()
         dereference(orphan)
         self.assertIsNone(orphan.post)
예제 #29
0
    def measure_results(
            self,
            trial_results: typing.Iterable[TrialResult]) -> FrameErrorResult:
        """
        Collect per-frame errors and tracking statistics across repeated trials.

        Each trial's estimated trajectory is aligned to the ground truth.
        Trials without scale also have their estimated motions rescaled.
        A FrameError is then built for every frame of every trial. Runs of
        tracking / lost frames are tallied per trial as frame counts,
        distances and times.

        TODO: Track the error introduced by a loop closure, somehow.
        Might need to track loop closures in the FrameResult

        :param trial_results: The results of several trials to aggregate.
            They must share the same system (see check_trial_collection) and
            have the same number of frame results.
        :return: The result of make_frame_error_result on success. On failure,
            a MetricResult or FrameErrorResult with success=False and a message
            describing the problem.
        """
        trial_results = list(trial_results)

        # preload model types for the models linked to the trial results.
        with no_auto_dereference(SLAMTrialResult):
            model_ids = set(tr.system for tr in trial_results
                            if isinstance(tr.system, bson.ObjectId))
            autoload_modules(VisionSystem, list(model_ids))
            model_ids = set(tr.image_source for tr in trial_results
                            if isinstance(tr.image_source, bson.ObjectId))
            autoload_modules(ImageSource, list(model_ids))

        # Check if the set of trial results is valid. Loads the models.
        invalid_reason = check_trial_collection(trial_results)
        if invalid_reason is not None:
            return MetricResult(metric=self,
                                trial_results=trial_results,
                                success=False,
                                message=invalid_reason)

        # Make sure we have a non-zero number of trials to measure
        if len(trial_results) <= 0:
            return MetricResult(metric=self,
                                trial_results=trial_results,
                                success=False,
                                message="Cannot measure zero trials.")

        # Ensure the trials all have the same number of results
        for repeat, trial_result in enumerate(trial_results[1:]):
            if len(trial_result.results) != len(trial_results[0].results):
                return MetricResult(
                    metric=self,
                    trial_results=trial_results,
                    success=False,
                    message=
                    f"Repeat {repeat + 1} has a different number of frames "
                    f"({len(trial_result.results)} != {len(trial_results[0].results)})"
                )

        # Load the system, it must be the same for all trials (see check_trial_collection)
        system = trial_results[0].system

        # Pre-load the image objects in a batch, to avoid loading them piecemeal later.
        # frame_result.image is deliberately not dereferenced in the loop below;
        # make_frame_error only compares ids against these preloaded images.
        images = [image for _, image in trial_results[0].image_source]

        # Build mappings between frame result timestamps and poses for each trial
        timestamps_to_pose = [{
            frame_result.timestamp: frame_result.pose
            for frame_result in trial_result.results
        } for trial_result in trial_results]

        # Choose transforms between each trajectory and the ground truth
        estimate_origins_and_scales = [
            robust_align_trajectory_to_ground_truth(
                [
                    frame_result.estimated_pose
                    for frame_result in trial_result.results
                    if frame_result.estimated_pose is not None
                ], [
                    frame_result.pose for frame_result in trial_result.results
                    if frame_result.estimated_pose is not None
                ],
                compute_scale=not bool(trial_result.has_scale),
                use_symmetric_scale=True) for trial_result in trial_results
        ]
        # Trials that report scale keep their motions as-is (scale 1.0)
        motion_scales = [1.0] * len(trial_results)
        for idx in range(len(trial_results)):
            if not trial_results[idx].has_scale:
                motion_scales[idx] = robust_compute_motions_scale(
                    [
                        frame_result.estimated_motion
                        for frame_result in trial_results[idx].results
                        if frame_result.estimated_motion is not None
                    ],
                    [
                        frame_result.motion
                        for frame_result in trial_results[idx].results
                        if frame_result.estimated_motion is not None
                    ],
                )

        # Then, tally all the errors for all the computed trajectories
        estimate_errors = [[] for _ in range(len(trial_results))]
        distances_lost = [[] for _ in range(len(trial_results))]
        times_lost = [[] for _ in range(len(trial_results))]
        frames_lost = [[] for _ in range(len(trial_results))]
        distances_found = [[] for _ in range(len(trial_results))]
        times_found = [[] for _ in range(len(trial_results))]
        frames_found = [[] for _ in range(len(trial_results))]

        # Per-trial state for the current run of tracking / lost frames
        is_tracking = [False for _ in range(len(trial_results))]
        tracking_frames = [0 for _ in range(len(trial_results))]
        tracking_distances = [0 for _ in range(len(trial_results))]
        prev_tracking_time = [0 for _ in range(len(trial_results))]
        current_tracking_time = [0 for _ in range(len(trial_results))]

        for frame_idx, frame_results in enumerate(
                zip(*(trial_result.results
                      for trial_result in trial_results))):
            # Get the estimated motions and absolute poses for each trial,
            # And convert them to the ground truth coordinate frame using
            # the scale, translation and rotation we chose
            scaled_motions = [
                tf.Transform(
                    location=frame_results[idx].estimated_motion.location *
                    motion_scales[idx],
                    rotation=frame_results[idx].estimated_motion.rotation_quat(
                        True),
                    w_first=True)
                if frame_results[idx].estimated_motion is not None else None
                for idx in range(len(frame_results))
            ]
            scaled_poses = [
                align_point(pose=frame_results[idx].estimated_pose,
                            shift=estimate_origins_and_scales[idx][0],
                            rotation=estimate_origins_and_scales[idx][1],
                            scale=estimate_origins_and_scales[idx][2])
                if frame_results[idx].estimated_pose is not None else None
                for idx in range(len(frame_results))
            ]

            # Find the average estimated motion for this frame across all the different trials
            # The average is not available for frames with only a single estimate
            non_null_motions = [
                motion for motion in scaled_motions if motion is not None
            ]
            if len(non_null_motions) > 1:
                average_motion = tf.compute_average_pose(non_null_motions)
            else:
                average_motion = None

            for repeat_idx, frame_result in enumerate(frame_results):

                # Record how long the current tracking state has persisted
                if frame_idx <= 0:
                    # Cannot change to or from tracking on the first frame
                    is_tracking[repeat_idx] = (frame_result.tracking_state is
                                               TrackingState.OK)
                    prev_tracking_time[repeat_idx] = frame_result.timestamp
                elif is_tracking[
                        repeat_idx] and frame_result.tracking_state is not TrackingState.OK:
                    # This trial has become lost, add to the list and reset the counters
                    frames_found[repeat_idx].append(
                        tracking_frames[repeat_idx])
                    distances_found[repeat_idx].append(
                        tracking_distances[repeat_idx])
                    times_found[repeat_idx].append(
                        current_tracking_time[repeat_idx] -
                        prev_tracking_time[repeat_idx])
                    tracking_frames[repeat_idx] = 0
                    tracking_distances[repeat_idx] = 0
                    prev_tracking_time[repeat_idx] = current_tracking_time[
                        repeat_idx]
                    is_tracking[repeat_idx] = False
                elif not is_tracking[
                        repeat_idx] and frame_result.tracking_state is TrackingState.OK:
                    # This trial has started to track, record how long it was lost for
                    frames_lost[repeat_idx].append(tracking_frames[repeat_idx])
                    distances_lost[repeat_idx].append(
                        tracking_distances[repeat_idx])
                    times_lost[repeat_idx].append(
                        current_tracking_time[repeat_idx] -
                        prev_tracking_time[repeat_idx])
                    tracking_frames[repeat_idx] = 0
                    tracking_distances[repeat_idx] = 0
                    prev_tracking_time[repeat_idx] = current_tracking_time[
                        repeat_idx]
                    is_tracking[repeat_idx] = True

                # Update the current tracking information
                tracking_frames[repeat_idx] += 1
                tracking_distances[repeat_idx] += np.linalg.norm(
                    frame_result.motion.location)
                current_tracking_time[repeat_idx] = frame_result.timestamp

                # Turn loop closures into distances. We don't need to worry about origins because everything is GT frame
                if len(frame_result.loop_edges) > 0:
                    loop_distances, loop_angles = compute_loop_distances_and_angles(
                        frame_result.pose,
                        (
                            timestamps_to_pose[repeat_idx][timestamp]
                            for timestamp in frame_result.loop_edges
                            if timestamp in timestamps_to_pose[
                                repeat_idx]  # they should all be in there, but for safety, check
                        ))
                else:
                    loop_distances, loop_angles = [], []

                # Build the frame error
                frame_error = make_frame_error(
                    trial_result=trial_results[repeat_idx],
                    frame_result=frame_result,
                    image=images[frame_idx],
                    system=system,
                    repeat_index=repeat_idx,
                    loop_distances=loop_distances,
                    loop_angles=loop_angles,
                    # Compute the error in the absolute estimated pose (if available)
                    absolute_error=make_pose_error(
                        scaled_poses[repeat_idx],  # The aligned estimated pose
                        frame_result.pose)
                    if scaled_poses[repeat_idx] is not None else None,
                    # Compute the error of the motion relative to the true motion
                    relative_error=make_pose_error(scaled_motions[repeat_idx],
                                                   frame_result.motion)
                    if scaled_motions[repeat_idx] is not None else None,
                    # Compute the error between the motion and the average estimated motion
                    noise=make_pose_error(scaled_motions[repeat_idx],
                                          average_motion)
                    if scaled_motions[repeat_idx] is not None
                    and average_motion is not None else None,
                    systemic_error=make_pose_error(average_motion,
                                                   frame_result.motion)
                    if average_motion is not None else None)
                estimate_errors[repeat_idx].append(frame_error)

        # Add any accumulated tracking information left over at the end
        if len(trial_results[0].results) > 0:
            for repeat_idx, tracking in enumerate(is_tracking):
                if tracking:
                    frames_found[repeat_idx].append(
                        tracking_frames[repeat_idx])
                    distances_found[repeat_idx].append(
                        tracking_distances[repeat_idx])
                    times_found[repeat_idx].append(
                        current_tracking_time[repeat_idx] -
                        prev_tracking_time[repeat_idx])
                else:
                    frames_lost[repeat_idx].append(tracking_frames[repeat_idx])
                    distances_lost[repeat_idx].append(
                        tracking_distances[repeat_idx])
                    times_lost[repeat_idx].append(
                        current_tracking_time[repeat_idx] -
                        prev_tracking_time[repeat_idx])

        # Once we've tallied all the results, either succeed or fail based on the number of results.
        if len(estimate_errors) <= 0 or any(
                len(trial_errors) <= 0 for trial_errors in estimate_errors):
            return FrameErrorResult(
                metric=self,
                trial_results=trial_results,
                success=False,
                message="No measurable errors for these trajectories")
        return make_frame_error_result(
            metric=self,
            trial_results=trial_results,
            errors=[
                TrialErrors(frame_errors=estimate_errors[repeat],
                            frames_lost=frames_lost[repeat],
                            frames_found=frames_found[repeat],
                            times_lost=times_lost[repeat],
                            times_found=times_found[repeat],
                            distances_lost=distances_lost[repeat],
                            distances_found=distances_found[repeat])
                for repeat in range(len(trial_results))
            ])
예제 #30
0
def _referenced_id(model, field_name: str):
    """Return the id held by reference field ``field_name`` on ``model`` without loading it.

    Auto-dereferencing is disabled for the model's type while reading. The
    field value may be either a raw ``bson.ObjectId`` or an object exposing
    ``.pk`` (an already-loaded document); both are handled.
    """
    with no_auto_dereference(type(model)):
        value = getattr(model, field_name)
        if isinstance(value, bson.ObjectId):
            return value
        return value.pk


def make_frame_error(
        trial_result: TrialResult,
        frame_result: FrameResult,
        image: typing.Union[None, Image],
        system: typing.Union[None, VisionSystem],
        repeat_index: int,
        absolute_error: typing.Union[None, PoseError],
        relative_error: typing.Union[None, PoseError],
        noise: typing.Union[None, PoseError],
        systemic_error: typing.Union[None, PoseError],
        loop_distances: typing.Iterable[float],
        loop_angles: typing.Iterable[float]
) -> FrameError:
    """
    Construct a frame_error object from a context
    The frame error copies data from its linked objects (like the system or image)
    to avoid having to dereference them later.
    This function makes sure that data is consistent.

    It takes the system and image, even though the trial result and frame result should refer to those
    because you're usually creating lots of FrameError objects at the same time, so you should load
    those objects _once_, and pass them in each time.
    Just because you've loaded the object doesn't mean the FrameResult has that object

    :param trial_result: The trial result producing this FrameError
    :param frame_result: The FrameResult for the specific frame this error corresponds to
    :param image: The specific image this error corresponds to. Will be pulled from frame_result if None
        or if its id does not match the one frame_result refers to.
    :param system: The system that produced the trial_result. Will be pulled from the trial result if None
        or if its id does not match the one trial_result refers to.
    :param repeat_index: The repeat index of the trial result, for identification within the set
    :param absolute_error: The error in the estimated pose, in an absolute reference frame.
    :param relative_error: The error in the estimated motion, relative to the previous frame.
    :param noise: The error between this particular motion estimate, and the average motion estimate from all trials.
    :param systemic_error: The difference between the mean estimate and the true motion. Should be near zero.
    :param loop_distances: The distance to other images to which this image has a loop closure. Will usually be empty.
    :param loop_angles: The angle to other images to which this image has a loop closure.
    :return: A FrameError object, containing the errors, and related metadata.
    :raises ValueError: If loop_distances and loop_angles have different lengths.
    """
    # Make sure the image we're given is the same as the one from the frame_result, without reloading it
    if image is None:
        image = frame_result.image
    elif _referenced_id(frame_result, 'image') != image.pk:
        image = frame_result.image

    # Make sure the given system matches the trial result, avoiding loading it unnecessarily
    if system is None:
        system = trial_result.system
    elif _referenced_id(trial_result, 'system') != system.pk:
        system = trial_result.system

    # Check that the loop distances and angles are the same
    loop_distances = list(loop_distances)
    loop_angles = list(loop_angles)
    if len(loop_distances) != len(loop_angles):
        raise ValueError("Loop distances and loop angles must always be the same length, "
                         f"was {loop_distances} and {loop_angles}")

    # Read the system properties from the trial result
    system_properties = system.get_properties(None, trial_result.settings)
    image_properties = image.get_properties()
    return FrameError(
        trial_result=trial_result,
        image=image,
        repeat=repeat_index,
        timestamp=frame_result.timestamp,
        motion=frame_result.motion,
        processing_time=frame_result.processing_time,
        loop_distances=loop_distances,
        loop_angles=loop_angles,
        num_features=frame_result.num_features,
        num_matches=frame_result.num_matches,
        tracking=frame_result.tracking_state,
        absolute_error=absolute_error,
        relative_error=relative_error,
        noise=noise,
        systemic_error=systemic_error,
        system_properties={str(k): json_value(v) for k, v in system_properties.items()},
        image_properties={str(k): json_value(v) for k, v in image_properties.items()}
    )