Example 1
    async def test_iter_gridfs(self):
        gfs = AsyncIOMotorGridFSBucket(self.db)

        async def cleanup():
            await self.db.fs.files.delete_many({})
            await self.db.fs.chunks.delete_many({})

        await cleanup()

        # Empty iterator.
        async for _ in gfs.find({'_id': 1}):
            self.fail()

        data = b'data'

        for n_files in 1, 2, 10:
            for i in range(n_files):
                async with gfs.open_upload_stream(filename='filename') as f:
                    await f.write(data)

            # Force extra batches to test iteration.
            j = 0
            async for _ in gfs.find({'filename': 'filename'}).batch_size(3):
                j += 1

            self.assertEqual(j, n_files)
            await cleanup()

        await gfs.upload_from_stream_with_id(1,
                                             'filename',
                                             source=data,
                                             chunk_size_bytes=1)
        cursor = gfs.find({'_id': 1})
        await cursor.fetch_next
        gout = cursor.next_object()
        chunks = []
        async for chunk in gout:
            chunks.append(chunk)

        self.assertEqual(len(chunks), len(data))
        self.assertEqual(b''.join(chunks), data)
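
The test above exercises Motor's GridFS bucket API directly. Below is a minimal standalone sketch of the same upload and streaming-download calls; the connection string, database name ("example_db"), and filename ("example.bin") are placeholders, not values from the original code.

import asyncio

from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorGridFSBucket


async def main():
    # Placeholder connection string and database name.
    db = AsyncIOMotorClient("mongodb://localhost:27017")["example_db"]
    bucket = AsyncIOMotorGridFSBucket(db)

    # Upload: AsyncIOMotorGridIn works as an async context manager.
    async with bucket.open_upload_stream("example.bin") as grid_in:
        await grid_in.write(b"payload")

    # Download by streaming chunks, mirroring the final assertions above.
    grid_out = await bucket.open_download_stream_by_name("example.bin")
    chunks = [chunk async for chunk in grid_out]
    assert b"".join(chunks) == b"payload"


asyncio.run(main())
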
Example 2
async def findlargemp4fileffmpeg(starttime, endtime):
    #print("begin findlargemp4fileffmpeg")
    mp4list = []
    client = AsyncIOMotorClient(ServerParameters.mongodbpath)
    db = client.jt808

    bucket = AsyncIOMotorGridFSBucket(db, "eventuploadvideos")
    cursor = bucket.find({
        "uploadDate": {
            '$gt': starttime,
            '$lte': endtime
        },
        "filename": {
            "$regex": ".mp4$"
        }
    })
    filelist = await cursor.to_list(100000)

    ccount = 0
    for fi in filelist:
        if fi["length"] > 1000000:
            print(fi)
            if os.path.exists(fi["filename"]):
                os.remove(fi["filename"])
            ds = await bucket.open_download_stream(fi["_id"])
            f = open("input" + fi["filename"], 'wb')
            bbb = await ds.read()
            f.write(bbb)
            f.close()
            ds.close()
            converttstoh264("input" + fi["filename"], fi["filename"])
            if os.path.exists("input" + fi["filename"]):
                os.remove("input" + fi["filename"])
            # Save the converted file back to the bucket
            try:
                if os.path.exists(fi["filename"]):
                    uf = open(fi["filename"], "rb")
                    ubbb = uf.read()
                    uf.close()
                    os.remove(fi["filename"])
                    bucket.delete(fi["_id"])
                    uds = bucket.open_upload_stream_with_id(
                        fi["_id"], fi["filename"])
                    await uds.write(ubbb)
                    await uds.close()
                    ccount = ccount + 1
                    logging.info("convert %s %s", fi["_id"], fi["filename"])
            except BaseException as e:
                logging.error(e)
    logging.info("end findlargemp4fileffmpeg total %s convert %s",
                 len(filelist), ccount)
    return
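
The conversion helper above deletes the original file and re-uploads the converted bytes under the same _id. Here is a hedged sketch of just that replace step in isolation; bucket, file_id, filename, and new_bytes are assumed to be supplied by the caller.

async def replace_gridfs_file(bucket, file_id, filename, new_bytes):
    # Drop the old file (files + chunks), then write the new bytes under the same _id.
    await bucket.delete(file_id)
    grid_in = bucket.open_upload_stream_with_id(file_id, filename)
    await grid_in.write(new_bytes)
    await grid_in.close()
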
Example 3
async def mix_two_files(config, song_a_name, song_b_name, bpm_a, bpm_b,
                        desired_bpm, mix_name, scenario_name,
                        transition_points, entry_point, exit_point,
                        num_songs_a, mix_id, mix_db: Collection,
                        fs: AsyncIOMotorGridFSBucket):
    # check that mongodb mix file exists
    mix_mongo = await mix_db.find_one({"_id": ObjectId(mix_id)})
    if mix_mongo:

        # read the original wav files
        song_a_path = f"{config['song_analysis_path']}/{song_a_name}"
        song_a_data = b""
        song_b_path = f"{config['song_analysis_path']}/{song_b_name}"
        song_b_data = b""
        cursor_a = fs.find({"filename": song_a_name})
        async for grid_data in cursor_a:
            song_a_data = await grid_data.read()
        with open(song_a_path, 'wb') as a_f:
            a_f.write(song_a_data)
        # song_a = util.read_wav_file(config, song_a_path, identifier='songA')
        song_a = util.read_wav_file(config,
                                    io.BytesIO(song_a_data),
                                    identifier='songA')

        cursor_b = fs.find({"filename": song_b_name})
        async for grid_data in cursor_b:
            song_b_data = await grid_data.read()
        with open(song_b_path, 'wb') as b_f:
            b_f.write(song_b_data)
        # song_b = util.read_wav_file(config, song_b_path, identifier='songB')
        song_b = util.read_wav_file(config,
                                    io.BytesIO(song_b_data),
                                    identifier='songB')

        # if num_songs_a > 1:
        #     song_a = util.read_wav_file(config, f"{config['mix_path']}/{song_a_name}", identifier='songA')
        # else:
        #     song_a = util.read_wav_file(config, f"{config['song_path']}/{song_a_name}", identifier='songA')
        # song_b = util.read_wav_file(config, f"{config['song_path']}/{song_b_name}", identifier='songB')
        # song_a = util.read_wav_file(config, f"{config['song_path']}/{song_a_name}", identifier='songA')
        # song_b = util.read_wav_file(config, f"{config['song_path']}/{song_b_name}", identifier='songB')

        update_data = {"progress": 20}
        mix_update0 = await mix_db.update_one({"_id": ObjectId(mix_id)},
                                              {"$set": update_data})
        if mix_update0.matched_count == 0:
            print("mix update #0 failed")

        # TSL = Transition Segment Length
        tsl_list = [
            config['transition_midpoint'],
            config['transition_length'] - config['transition_midpoint']
        ]

        # 1 match tempo of both songs before analysis
        # TODO write adjusted songs to db
        if desired_bpm != bpm_a:
            song_a_adjusted, song_b_adjusted = bpm_match.match_bpm_desired(
                config, song_a, song_b, desired_bpm, bpm_a, bpm_b)
        else:
            song_a_adjusted, song_b_adjusted = bpm_match.match_bpm_first(
                config, song_a, bpm_a, song_b, bpm_b)

        update_data = {"progress": 40}
        mix_update1 = await mix_db.update_one({"_id": ObjectId(mix_id)},
                                              {"$set": update_data})
        if mix_update1.matched_count == 0:
            print("mix update #1 failed")

        # 2. analyse songs
        if transition_points:
            transition_points['b'] = round(
                transition_points['a'] +
                (transition_points['d'] - transition_points['c']), 3)
            transition_points['x'] = round(
                transition_points['a'] +
                (transition_points['e'] - transition_points['c']), 3)
        if not transition_points:
            then = time.time()
            transition_points = analysis.get_transition_points(
                config, song_a_adjusted, song_b_adjusted, exit_point,
                entry_point, tsl_list)
            now = time.time()
            print("INFO - Analysing file took: %0.1f seconds. \n" %
                  (now - then))

        update_data = {"transition_points": transition_points, "progress": 60}
        mix_update2 = await mix_db.update_one({"_id": ObjectId(mix_id)},
                                              {"$set": update_data})
        if mix_update2.matched_count == 0:
            print("mix update #2 failed")

        print(f"Transition points (seconds): {transition_points}")
        print(
            f"Transition points (minutes): {util.get_length_for_transition_points(config, transition_points)}"
        )
        print(
            f"Transition interval lengths (C-D-E): {round(transition_points['d']-transition_points['c'], 3)}s, {round(transition_points['e']-transition_points['d'], 3)}s"
        )
        print(
            f"Transition interval lengths (A-B-X): {round(transition_points['b']-transition_points['a'], 3)}s, {round(transition_points['x']-transition_points['b'], 3)}s"
        )
        print()

        # 3. mix both songs
        then = time.time()
        frames = util.calculate_frames(config, song_a_adjusted,
                                       song_b_adjusted, transition_points)
        # print("Frames: %s" % frames)
        mixed_song = mixer.create_mixed_wav_file(config, song_a_adjusted,
                                                 song_b_adjusted,
                                                 transition_points, frames,
                                                 tsl_list, mix_name,
                                                 scenario_name)
        now = time.time()
        print("INFO - Mixing file took: %0.1f seconds" % (now - then))

        mix_name_wav = mixed_song['name']
        file_path_wav = mixed_song['path']
        with open(file_path_wav, 'rb') as f:
            grid_in = fs.open_upload_stream(mix_name_wav)
            await grid_in.write(f.read())
            await grid_in.close()
        update_data = {"title": mix_name_wav, "progress": 80}
        mix_update3 = await mix_db.update_one({"_id": ObjectId(mix_id)},
                                              {"$set": update_data})
        if mix_update3.matched_count == 0:
            print("mix update #3 failed")

        # 4. convert to mp3
        if mixed_song:
            mix_name_mp3 = converter.convert_result_to_mp3(
                config, mixed_song['name'])
            if mix_name_mp3:
                mixed_song['name_mp3'] = mix_name_mp3
                mixed_song['path_mp3'] = f"{config['mix_path']}/{mix_name_mp3}"

        mix_name_mp3 = mixed_song['name_mp3']
        file_path_mp3 = mixed_song['path_mp3']
        with open(file_path_mp3, 'rb') as f:
            grid_in = fs.open_upload_stream(mix_name_mp3)
            await grid_in.write(f.read())
            await grid_in.close()
        update_data = {"progress": 100, "title_mp3": mix_name_mp3}
        mix_update4 = await mix_db.update_one({"_id": ObjectId(mix_id)},
                                              {"$set": update_data})
        if mix_update4.matched_count == 0:
            print("mix update #4 failed")

        # 5. export json data
        scenario_data = util.get_scenario(config, scenario_name)
        scenario_data['short_name'] = scenario_name
        new_num_songs = num_songs_a + 1
        json_data = util.export_transition_parameters_to_json(
            config, [song_a, song_b, mixed_song], transition_points,
            scenario_data, tsl_list, new_num_songs, desired_bpm)

        os.remove(song_a_path)
        os.remove(song_b_path)
        return json_data
    else:
        return error_response_model("Not Found", 404,
                                    f"Mix with id {mix_id} does not exist")
Example 4
class AsyncIOClient(MongoClient):
    """
High-level AsyncIOMotorClient subclass with additional methods added for ease-of-use,
having some automated conveniences and defaults.
    """
    _MONGO_URI = lambda _: getattr(Config, "MONGO_URI", None)
    _DEFAULT_COLLECTION = None
    _KWARGS = None
    _LOGGING_COND_GET = None
    _LOGGING_COND_POST = None
    _LOGGING_COND_PUT = None
    _LOGGING_COND_PATCH = None
    _LOGGING_COND_DELETE = None

    def __init__(self, mongo_uri=None, default_collection=None, **kwargs):
        self._MONGO_URI = mongo_uri or self._MONGO_URI
        if callable(self._MONGO_URI):
            self._MONGO_URI = self._MONGO_URI()
        self._DEFAULT_COLLECTION = default_collection or self._DEFAULT_COLLECTION

        if kwargs:
            self._KWARGS = kwargs.copy()

        # Iterate over a copy of the keys because kwargs is mutated inside the loop.
        for kwarg in list(kwargs.keys()):
            if kwarg.lower() in ('logging_cond_get', 'logging_cond_post',
                                 'logging_cond_put', 'logging_cond_patch',
                                 'logging_cond_delete'):
                setattr(self, kwarg.upper(), kwargs.pop(kwarg))

        MongoClient.__init__(self, self._MONGO_URI, **kwargs)

        db = self.get_default_database()
        logger.info("db detected '{}' of type '{}'".format(db.name, type(db.name)))
        if not getattr(db, "name", None) or db.name == "None":
            logger.warning("database not provided in MONGO_URI, assign with method set_database")
            logger.warning("gridfsbucket not instantiated due to missing database")
        else:
            global SUPPORT_ASYNCIO_BUCKET
            if SUPPORT_ASYNCIO_BUCKET:
                logger.debug("gridfsbucket instantiated under self.FILES")
                self.FILES = GridFSBucket(db)
            else:
                logger.warning("gridfsbucket not instantiated due to missing 'tornado' package")
                self.FILES = None

    def __repr__(self):
        db = self.get_default_database()

        if not getattr(db, "name", None) or db.name == "None":
            return "<cervmongo.AsyncIOClient>"
        else:
            return f"<cervmongo.AsyncIOClient.{db.name}>"

    def _process_record_id_type(self, record):
        one = False
        if isinstance(record, str):
            one = True
            if "$oid" in record:
                record = {"$in": [json_load(record), record]}
            else:
                try:
                    record = {"$in": [DOC_ID.__supertype__(record), record]}
                except:
                    pass
        elif isinstance(record, DOC_ID.__supertype__):
            record = record
            one = True
        elif isinstance(record, dict):
            if "$oid" in record or "$regex" in record:
                record = json_dump(record)
                record = json_load(record)
                one = True
        return (record, one)

    def set_database(self, database):
        Config.set_mongo_db(database)
        if self._KWARGS:
            AsyncIOClient.__init__(self, mongo_uri=Config.MONGO_URI, default_collection=self._DEFAULT_COLLECTION, **self._KWARGS)
        else:
            AsyncIOClient.__init__(self, mongo_uri=Config.MONGO_URI, default_collection=self._DEFAULT_COLLECTION)

    def COLLECTION(self, collection:str):

        self._DEFAULT_COLLECTION = collection


        class CollectionClient:
            __parent__ = CLIENT = self
            # INFO: variables
            _DEFAULT_COLLECTION = collection
            _MONGO_URI = self._MONGO_URI
            # INFO: general methods
            GENERATE_ID = self.GENERATE_ID
            COLLECTION = self.COLLECTION
            # INFO: GridFS file operations
            UPLOAD = self.UPLOAD
            DOWNLOAD = self.DOWNLOAD
            ERASE = self.ERASE
            # INFO: truncated Collection methods
            INDEX = partial(self.INDEX, collection)
            ADD_FIELD = partial(self.ADD_FIELD, collection)
            REMOVE_FIELD = partial(self.REMOVE_FIELD, collection)
            DELETE = partial(self.DELETE, collection)
            GET = partial(self.GET, collection)
            POST = partial(self.POST, collection)
            PUT = partial(self.PUT, collection)
            PATCH = partial(self.PATCH, collection)
            REPLACE = partial(self.REPLACE, collection)
            SEARCH = partial(self.SEARCH, collection)
            PAGINATED_QUERY = partial(self.PAGINATED_QUERY, collection)
            def __repr__(s):
                return "<cervmongo.AsyncIOClient.CollectionClient>"
            def get_client(s):
                return s.CLIENT
        return CollectionClient()

    async def PAGINATED_QUERY(self, collection, limit:int=20,
                                sort:PAGINATION_SORT_FIELDS=PAGINATION_SORT_FIELDS["_id"],
                                after:str=None, before:str=None,
                                page:int=None, endpoint:str="/",
                                ordering:int=-1, query:dict={}, **kwargs):
        """
            Returns paginated results of collection w/ query.

            Available pagination methods:
             - **Cursor-based (default)**
                - after
                - before
                - limit (results per page, default 20)
             - **Time-based** (a datetime field must be selected)
                - sort (set to datetime field)
                - after (records after this time)
                - before (records before this time)
                - limit (results per page, default 20)
             - **Offset-based** (not recommended)
                - limit (results per page, default 20)
                - page
        """
        collection = collection or self._DEFAULT_COLLECTION
        assert collection, "collection must be of type str"

        if isinstance(sort, ENUM.__supertype__):
            sort = sort.value

        total_docs = await self.GET(collection, query, count=True, empty=0)

        if not page:
            if sort == "_id":
                pagination_method = "cursor"
            else:
                pagination_method = "time"
            cursor = await self.GET(collection, query,
                                    limit=limit, key=sort, before=before,
                                    after=after, sort=ordering, empty=[])
        else:
            assert page >= 1, "page must be equal to or greater than 1"
            pagination_method = "offset"
            cursor = await self.GET(collection, query,
                                    perpage=limit, key=sort, page=page,
                                    sort=ordering, empty=[])

        results = [ record async for record in cursor ]

        # INFO: determine 'cursor' template
        if sort == "_id":
            template = "_{_id}"
        else:
            template = "{date}_{_id}"

        new_after = None
        new_before = None

        if results:
            _id = results[-1]["_id"]
            try:
                date = results[-1][sort].isoformat()
            except:
                date = None
            if len(results) == limit:
                new_after = template.format(_id=_id, date=date)

            _id = results[0]["_id"]
            try:
                date = results[0][sort].isoformat()
            except:
                date = None
            if any((after, before)):
                new_before = template.format(_id=_id, date=date)

            if pagination_method in ("cursor", "time"):
                if before:
                    check_ahead = await self.GET(collection, query,
                                            limit=limit, key=sort, before=new_before, empty=0, count=True)
                    if not check_ahead:
                        new_before = None
                elif after:
                    check_ahead = await self.GET(collection, query,
                                            limit=limit, key=sort, after=new_after, empty=0, count=True)
                    if not check_ahead:
                        new_after = None

        response = {
            "data": results,
            "details": {
                "pagination_method": pagination_method,
                "query": dict_to_query(query),
                "sort": sort,
                "unique_id": getattr(self, "_UNIQUE_ID", "_id"),
                "total": total_docs,
                "count": len(results),
                "limit": limit
                }
            }


        # TODO: Refactor
        if pagination_method in ("cursor", "time"):
            response["details"]["cursors"] = {
                  "after": new_after,
                  "before": new_before
                }
            before_url_template = "{endpoint}?sort={sort}&limit={limit}&before={before}"
            after_url_template = "{endpoint}?sort={sort}&limit={limit}&after={after}"
        else: # INFO: pagination_method == "offset"
            response["details"]["cursors"] = {
                  "prev_page": page - 1 if page > 1 else None,
                  "next_page": page + 1 if (page * limit) <= total_docs else None
                }
            before_url_template = "{endpoint}?sort={sort}&limit={limit}&page={page}"
            after_url_template = "{endpoint}?sort={sort}&limit={limit}&page={page}"

        if new_before:
            response["details"]["previous"] = before_url_template.format(
                                                                    endpoint=endpoint,
                                                                    sort=sort,
                                                                    page=page,
                                                                    limit=limit,
                                                                    after=new_after,
                                                                    before=new_before)
        else:
            response["details"]["previous"] = None

        if new_after:
            response["details"]["next"] = after_url_template.format(
                                                                    endpoint=endpoint,
                                                                    sort=sort,
                                                                    page=page,
                                                                    limit=limit,
                                                                    after=new_after,
                                                                    before=new_before)
        else:
            response["details"]["next"] = None

        return response
    PAGINATED_QUERY.clean_kwargs = lambda kwargs: _clean_kwargs(ONLY=("limit", "sort", "after",
                                            "before", "page", "endpoint", "query"), kwargs=kwargs)

    def GENERATE_ID(self, _id=None):
        if _id:
            return DOC_ID.__supertype__(_id)
        else:
            return DOC_ID.__supertype__()

    async def UPLOAD(self, fileobj, filename:str=None, content_type:str=None, extension:str=None, **kwargs):
        assert self.FILES, "GridFS instance not initialized, run method 'set_database' with the desired database and try again"
        fileobj = file_and_fileobj(fileobj)
        metadata = get_file_meta_information(fileobj, filename=filename, content_type=content_type, extension=extension)
        filename = metadata['filename']
        metadata.update(kwargs)
        file_id = await self.FILES.upload_from_stream(filename, fileobj, metadata=metadata)
        return file_id

    async def ERASE(self, filename_or_id, revision:int=-1):
        assert self.FILES, "GridFS instance not initialized, run method 'set_database' with the desired database and try again"
        fs_doc = await self.DOWNLOAD(filename_or_id, revision=revision)
        await self.FILES.delete(fs_doc._id)
        await fs_doc.close()

    async def DOWNLOAD(self, filename_or_id=None, revision:int=-1, skip:int=None, limit:int=None, sort:int=-1, **query):
        assert self.FILES, "GridFS instance not initialized, run method 'set_database' with the desired database and try again"
        revision = int(revision)
        if filename_or_id:
            if isinstance(filename_or_id, DOC_ID.__supertype__):
                return await self.FILES.open_download_stream(filename_or_id)
            else:
                return await self.FILES.open_download_stream_by_name(filename_or_id, revision=revision)

        return self.FILES.find(query, limit=limit, skip=skip, sort=sort, no_cursor_timeout=True)

    async def DELETE(self, collection, record, soft:bool=False, one:bool=False):
        db = self.get_default_database()
        if not collection:
            if hasattr(self, '_DEFAULT_COLLECTION'):
                collection = self._DEFAULT_COLLECTION
        assert collection, "collection must be of type str"

        o_collection = collection[:]

        collection = db[collection]

        if not isinstance(record, (list, tuple)):
            record, _one = self._process_record_id_type(record)
            one = _one if _one else one
            if _one:
                record = {"_id": record}
        else:
            record = self._process_record_id_type(record)[0]

        if soft:
            data_record = await self.GET(o_collection, record)
            try:
                await self.PUT("deleted."+o_collection, data_record)
            except:
                data_record.pop("_id")
                await self.PUT("deleted."+o_collection, data_record)

        if isinstance(record, (str, ObjectId)):
            return await collection.delete_one({"_id": record})
        elif isinstance(record, dict):
            if one:
                return await collection.delete_one(record)
            else:
                return await collection.delete_many(record)
        else:
            results = []
            for _id in record:
                results.append(await collection.delete_one({"_id": _id}))
            return results

    def INDEX(self, collection, key:str="_id", sort:int=1, unique:bool=False, reindex:bool=False):
        db = self.get_default_database()
        if not collection:
            if hasattr(self, '_DEFAULT_COLLECTION'):
                collection = self._DEFAULT_COLLECTION
        assert collection, "collection must be of type str"

        collection = db[collection]

        name = "%sIndex%s" % (key, "Asc" if sort == 1 else "Desc")
        try:
            if name not in collection.index_information():
                collection.create_index([
                    (key, sort)], name=name, background=True, unique=unique)
        except:
            #print((_traceback()))
            pass

    async def ADD_FIELD(self, collection, field:str, value:typing.Union[typing.Dict, typing.List, str, int, float, bool]='', data=False, query:dict=None):
        if not collection:
            if hasattr(self, '_DEFAULT_COLLECTION'):
                collection = self._DEFAULT_COLLECTION
        assert collection, "collection must be of type str"

        query.update({field: {"$exists": False}})
        if data:
            records = await self.GET(collection, query, fields={
                data: True}, empty=[])
        else:
            records = await self.GET(collection, query, fields={
                "_id": True}, empty=[])

        for record in records:
            if data:
                await self.PATCH(collection, record["_id"], {"$set": {
                    field: record[data]}})
            else:
                await self.PATCH(collection, record["_id"], {"$set": {
                    field: value}})

    async def REMOVE_FIELD(self, collection, field:str, query:dict=None) -> None:
        if not collection:
            collection = self._DEFAULT_COLLECTION
        assert collection, "collection must be of type str"
        query.update({field: {"$exists": True}})
        records = await self.GET(collection, query, distinct=True)

        for record in records:
            await self.PATCH(collection, record, {"$unset": {field: ""}})

    async def GET(self, collection, id_or_query:typing.Union[DOC_ID, str, typing.Dict]={}, sort:int=1, key:str="_id", count:bool=None, search:str=None, fields:dict=None, page:int=None, perpage:int=False, limit:int=None, after:str=None, before:str=None, empty=None, distinct:str=None, one:bool=False, **kwargs):
        db = self.get_default_database()
        collection = collection or self._DEFAULT_COLLECTION
        assert collection, "collection not provided"

        if not isinstance(collection, (list, tuple, types.GeneratorType)):
            collection = [collection]
        cols = list(set(collection))
        results = []
        number_of_results = len(cols)

        if distinct is True:
            distinct = "_id"

        id_or_query, _one = self._process_record_id_type(id_or_query)
        one = _one if _one else one
        if _one:
            query = {"_id": id_or_query}
        else:
            query = id_or_query

        for collection in cols:
            collection = db[collection]

            if query or not search:
                if count and not limit:
                    if query:
                        results.append(await collection.count_documents(query, **kwargs))
                    else:
                        results.append(await collection.estimated_document_count(**kwargs))
                elif distinct:
                    cursor = await collection.distinct(distinct, filter=query, **kwargs)
                    results.append(sorted(cursor))
                elif perpage:
                    total = (page - 1) * perpage
                    cursor = collection.find(query, projection=fields, **kwargs)
                    results.append(cursor.sort([(key, sort)]).skip(total).limit(perpage))
                elif limit:
                    if any((query, after, before)):
                        query = {"$and": [
                                    query
                                ]}
                    if after or before:
                        if after:
                            sort_value, _id_value = after.split("_")
                            _id_value = DOC_ID.__supertype__(_id_value)
                            query["$and"].append({"$or": [
                                            {key: {"$lt": _id_value}}
                                        ]})
                            if key != "_id":
                                sort_value = dateparse(sort_value)
                                query["$and"][-1]["$or"].append({key: {"$lt": sort_value}, "_id": {"$lt": _id_value}})
                        elif before:
                            sort_value, _id_value = before.split("_")
                            _id_value = DOC_ID.__supertype__(_id_value)
                            query["$and"].append({"$or": [
                                            {key: {"$gt": _id_value}}
                                        ]})
                            if key != "_id":
                                sort_value = dateparse(sort_value)
                                query["$and"][-1]["$or"].append({key: {"$gt": sort_value}, "_id": {"$gt": _id_value}})

                    if count:
                        try:
                            cursor = await collection.count_documents(query, limit=limit, hint=[(key, sort)], **kwargs)
                        except:
                            cursor = len(await collection.find(query, fields, **kwargs).sort([(key, sort)]).to_list(limit))
                        results.append(cursor)
                    else:
                        cursor = collection.find(query, projection=fields, **kwargs).sort([(key, sort)]).limit(limit)
                        results.append(cursor)
                elif one:
                    val = await collection.find_one(query, projection=fields, sort=[(key, sort)], **kwargs)
                    results.append(val if val else empty)
                else:
                    cursor = collection.find(query, projection=fields, **kwargs).sort([(key, sort)])
                    results.append(cursor)
            elif search:
                try:
                    if count:
                        results.append(await collection.count_documents({"$text": {"$search": search}}))
                    elif distinct:
                        results.append(await collection.distinct(distinct, filter={"$text": {"$search": search}}))
                    else:
                        cursor = collection.find({"$text": {"$search": search}})
                        if perpage:
                            total = (page - 1) * perpage
                            results.append(cursor.sort([(key, sort)]).skip(total).limit(perpage))
                        else:
                            results.append(cursor.sort([(key, sort)]))
                except:
                    cursor = await collection.command('textIndex', search=search)
                    if count:
                        results.append(cursor.count())
                    elif distinct:
                        results.append(cursor.distinct(distinct))
                    else:
                        if perpage:
                            total = (page - 1) * perpage
                            results.append(cursor.sort([(key, sort)]).skip(total).limit(perpage))
                        else:
                            results.append(cursor.sort([(key, sort)]))
            else:
                raise Error("unidentified error")

        if number_of_results == 1:
            return results[0]
        else:
            return results

    async def SEARCH(self, collection, search:str, **kwargs):
        if not collection:
            if hasattr(self, '_DEFAULT_COLLECTION'):
                collection = self._DEFAULT_COLLECTION
        assert collection, "collection must be of type str"

        return await self.GET(collection, search=search, **kwargs)

    async def POST(self, collection, record_or_records:typing.Union[typing.List, typing.Dict]):
        db = self.get_default_database()
        collection = collection or self._DEFAULT_COLLECTION
        assert collection, "collection must be of type str"
        collection = db[collection]

        if isinstance(record_or_records, (list, tuple)):
            return await collection.insert_many(record_or_records)
        elif isinstance(record_or_records, dict):
            return await collection.insert_one(record_or_records)
        else:
            raise TypeError("invalid record type '{}' provided".format(type(record_or_records)))

    async def PUT(self, collection, record_or_records:typing.Union[typing.List, typing.Dict]):
        """
            creates or replaces record(s) with exact _id provided, _id is required with record object(s)

            returns original document, if replaced
        """
        db = self.get_default_database()
        collection = collection or self._DEFAULT_COLLECTION
        assert collection, "collection must be of type str"
        collection = db[collection]

        if isinstance(record_or_records, (list, tuple)):
            assert all([ record.get("_id", None) for record in record_or_records ]), "not all records provided contained an _id"
            return await collection.insert_many(record_or_records, ordered=False)
        elif isinstance(record_or_records, dict):
            assert record_or_records.get("_id", None), "no _id provided"
            query = {"_id": record_or_records["_id"]}
            return await collection.find_one_and_replace(query, record_or_records, upsert=True)
        else:
            raise TypeError("invalid record type '{}' provided".format(type(record_or_records)))

    async def REPLACE(self, collection, original, replacement:dict, upsert=False):
        db = self.get_default_database()
        if not collection:
            if hasattr(self, '_DEFAULT_COLLECTION'):
                collection = self._DEFAULT_COLLECTION
        assert collection, "collection must be of type str"

        collection = db[collection]

        return await collection.replace_one({"_id": original},
                    replacement, upsert=upsert)

    async def PATCH(self, collection, id_or_query:typing.Union[DOC_ID, typing.Dict, typing.List, str], updates:typing.Union[typing.Dict, typing.List], upsert:bool=False, w:int=1):
        db = self.get_default_database()
        collection = collection or self._DEFAULT_COLLECTION
        assert collection, "collection not provided"
        collection = db[collection]

        if w != 1:
            WRITE = WriteConcern(w=w)
            collection = collection.with_options(write_concern=WRITE)

        if isinstance(id_or_query, (str, DOC_ID.__supertype__)):
            assert isinstance(updates, dict), "updates must be dict"
            id_or_query, _ = self._process_record_id_type(id_or_query)
            query = {"_id": id_or_query}

            set_on_insert_id = {"$setOnInsert": query}
            updates.update(set_on_insert_id)

            results = await collection.update_one(query, updates, upsert=upsert)
            return results
        elif isinstance(id_or_query, dict):
            assert isinstance(updates, dict), "updates must be dict"
            results = await collection.update_many(id_or_query, updates, upsert=upsert)
            return results
        elif isinstance(id_or_query, (tuple, list)):
            assert isinstance(updates, (tuple, list)), "updates must be list or tuple"

            results = []
            for i, _id in enumerate(id_or_query):
                _id, _ = self._process_record_id_type(_id)
                query = {"_id": _id}
                set_on_insert_id = {"$setOnInsert": query}
                updates[i].update(set_on_insert_id)

                result = await collection.update_one(query, updates[i], upsert=upsert)
                results.append(result)

            return results
        else:
            raise Error("unidentified error")
Example 5
class LockboxDB:
    """
    Holds databases for lockbox.
    """
    def __init__(self, host: str, port: int):
        # Set up fernet
        # Read from base64 encoded key
        if os.environ.get("LOCKBOX_CREDENTIAL_KEY"):
            key = os.environ.get("LOCKBOX_CREDENTIAL_KEY")
        # Read from key file
        elif os.environ.get("LOCKBOX_CREDENTIAL_KEY_FILE"):
            try:
                with open(os.environ.get("LOCKBOX_CREDENTIAL_KEY_FILE"),
                          "rb") as f:
                    key = base64.b64encode(f.read())
            except IOError as e:
                raise ValueError(
                    "Cannot read password encryption key file") from e
        else:
            raise ValueError(
                "Encryption key for passwords must be provided! Set LOCKBOX_CREDENTIAL_KEY or LOCKBOX_CREDENTIAL_KEY_FILE."
            )
        # Should raise ValueError if key is invalid
        self.fernet = Fernet(key)

        if os.environ.get("LOCKBOX_SCHOOL"):
            try:
                self.school_code = int(os.environ["LOCKBOX_SCHOOL"])
            except ValueError as e:
                logger.error(f"Invalid school code: {e}")
                self.school_code = None
        else:
            self.school_code = None

        self.client = AsyncIOMotorClient(host, port)
        self._private_db = self.client["lockbox"]
        self._shared_db = self.client["shared"]
        self._private_instance = MotorAsyncIOInstance(self._private_db)
        self._shared_instance = MotorAsyncIOInstance(self._shared_db)
        self._shared_gridfs = AsyncIOMotorGridFSBucket(self._shared_db)

        self.LockboxFailureImpl = self._private_instance.register(
            documents.LockboxFailure)
        self.FillFormResultImpl = self._private_instance.register(
            documents.FillFormResult)
        self.UserImpl = self._private_instance.register(documents.User)
        self.FormGeometryEntryImpl = self._private_instance.register(
            documents.FormGeometryEntry)
        self.CachedFormGeometryImpl = self._private_instance.register(
            documents.CachedFormGeometry)
        self.TaskImpl = self._private_instance.register(documents.Task)

        self.FormFieldImpl = self._shared_instance.register(
            documents.FormField)
        self.FormImpl = self._shared_instance.register(documents.Form)
        self.CourseImpl = self._shared_instance.register(documents.Course)
        self.FormFillingTestImpl = self._shared_instance.register(
            documents.FormFillingTest)
        self.LockboxFailureImplShared = self._shared_instance.register(
            documents.LockboxFailure)
        self.FillFormResultImplShared = self._shared_instance.register(
            documents.FillFormResult)

        self._scheduler = scheduler.Scheduler(self)
        tasks.set_task_handlers(self._scheduler)
        # Current school day, set by the check day task
        # Used as a fallback & indicator of whether the day's been checked
        # None when the day has not been checked
        self.current_day = None

    async def init(self):
        """
        Initialize the databases and task scheduler.
        """
        await self.UserImpl.ensure_indexes()
        await self.CourseImpl.ensure_indexes()
        await self.CachedFormGeometryImpl.collection.drop()
        await self.CachedFormGeometryImpl.ensure_indexes()
        await self._scheduler.start()

        # Re-schedule the check day task if current day is not checked
        if self.current_day is None:
            await self._reschedule_check_day()

    def private_db(self) -> AsyncIOMotorDatabase:
        """
        Get a reference to the private database.
        """
        return self._private_db

    def shared_db(self) -> AsyncIOMotorDatabase:
        """
        Get a reference to the shared database.
        """
        return self._shared_db

    def shared_gridfs(self) -> AsyncIOMotorGridFSBucket:
        """
        Get a reference to the shared GridFS bucket.
        """
        return self._shared_gridfs

    async def _reschedule_check_day(self) -> None:
        """
        Reschedule the check day task.

        If the task is set to run later today, no action will be taken.
        If the task will not run today or does not exist, it will be scheduled immediately.
        """
        check_task = await self.TaskImpl.find_one(
            {"kind": documents.TaskType.CHECK_DAY.value})
        if check_task is None:
            # Create check task if it does not exist
            await self._scheduler.create_task(
                kind=documents.TaskType.CHECK_DAY)
        # Check if the task will run later today
        # If the check task is set to run on a different date then make it run now
        elif check_task.next_run_at.replace(
                tzinfo=datetime.timezone.utc).astimezone(
                    tasks.LOCAL_TZ).date() > datetime.datetime.today().date():
            check_task.next_run_at = datetime.datetime.utcnow()
            await check_task.commit()
            self._scheduler.update()

    async def populate_user_courses(self,
                                    user,
                                    courses: typing.List[TimetableItem],
                                    clear_previous: bool = True) -> None:
        """
        Populate a user's courses, creating new Course documents if new courses are encountered.

        If clear_previous is True, all previous courses will be cleared.
        However, the Course documents in the shared database will not be touched, since they might
        also be referred to by other users.
        """
        if clear_previous:
            user.courses = []
        else:
            user.courses = user.courses or []
        # Populate courses collection
        for course in courses:
            db_course = await self.CourseImpl.find_one(
                {"course_code": course.course_code})
            if db_course is None:
                db_course = self.CourseImpl(
                    course_code=course.course_code,
                    teacher_name=course.course_teacher_name)
                # Without this, known_slots for different courses will all point to the same instance of list
                db_course.known_slots = []
            else:
                # Make sure the teacher name is set
                if not db_course.teacher_name:
                    db_course.teacher_name = course.course_teacher_name
            # Fill in known slots
            slot_str = f"{course.course_cycle_day}-{course.course_period}"
            if slot_str not in db_course.known_slots:
                db_course.known_slots.append(slot_str)
            await db_course.commit()
            if db_course.pk not in user.courses:
                user.courses.append(db_course.pk)
        await user.commit()

    async def create_user(self) -> str:
        """
        Create a new user.

        Returns token on success.
        """
        token = secrets.token_hex(32)
        await self.UserImpl(token=token).commit()
        return token

    async def modify_user(
            self,
            token: str,
            login: str = None,
            password: str = None,  # pylint: disable=unused-argument
            active: bool = None,
            grade: int = None,
            first_name: str = None,
            last_name: str = None,
            **kwargs) -> None:
        """
        Modify user data.

        Also verifies credentials if modifying login or password.
        """
        user = await self.UserImpl.find_one({"token": token})
        if user is None:
            raise LockboxDBError("Bad token", LockboxDBError.BAD_TOKEN)
        try:
            if login is not None:
                user.login = login
            if password is not None:
                user.password = self.fernet.encrypt(password.encode("utf-8"))
            if active is not None:
                user.active = active
            if grade is not None:
                user.grade = grade
            if first_name is not None:
                user.first_name = first_name
            if last_name is not None:
                user.last_name = last_name
            # Verify user credentials if username and password are both present
            # and at least one is being modified
            if user.login is not None and user.password is not None and (
                    login is not None or password is not None):
                logger.info(f"Verifying credentials for login {user.login}")
                try:
                    async with TDSBConnects() as session:
                        await session.login(login, password)
                        info = await session.get_user_info()
                        schools = info.schools
                        if self.school_code is None:
                            if len(schools) != 1:
                                logger.info(
                                    f"Login {user.login} has an invalid number of schools."
                                )
                                raise LockboxDBError(
                                    f"TDSB Connects reported {len(schools)} schools; nffu can only handle 1 school",
                                    LockboxDBError.OTHER)
                            school = schools[0]
                        else:
                            for s in schools:
                                if s.code == self.school_code:
                                    school = s
                                    break
                            else:
                                logger.info(
                                    f"Login {user.login} is not in the configured school"
                                )
                                raise LockboxDBError(
                                    f"You do not appear to be in the school nffu was set up for (#{self.school_code}); nffu can only handle 1 school",
                                    LockboxDBError.OTHER)
                        user.email = info.email
                        # Try to get user grade, first name, and last name
                        try:
                            user.grade = int(
                                info._data["SchoolCodeList"][0]["StudentInfo"]
                                ["CurrentGradeLevel"])
                            # CurrentGradeLevel increments once per *calendar* year
                            # So the value is off-by-one during the first half of the school year
                            # School year is in the form XXXXYYYY, e.g. 20202021
                            if not school.school_year.endswith(
                                    str(datetime.datetime.now().year)):
                                user.grade += 1
                        except (ValueError, KeyError, IndexError):
                            pass
                        try:
                            user.first_name = info._data["SchoolCodeList"][0][
                                "StudentInfo"]["FirstName"]
                            user.last_name = info._data["SchoolCodeList"][0][
                                "StudentInfo"]["LastName"]
                        except (ValidationError, KeyError, IndexError):
                            pass
                except aiohttp.ClientResponseError as e:
                    logger.info(f"TDSB login error for login {user.login}")
                    # Invalid credentials, clean up and raise
                    if e.status == 401:
                        raise LockboxDBError(
                            "Incorrect TDSB credentials",
                            LockboxDBError.INVALID_FIELD) from e
                    raise LockboxDBError(
                        f"HTTP error while logging into TDSB Connects: {str(e)}"
                    ) from e
                # Now we know credentials are valid
                await user.commit()
                logger.info(f"Credentials good for login {user.login}")
                await self._scheduler.create_task(
                    kind=documents.TaskType.POPULATE_COURSES, owner=user)
            else:
                await user.commit()

            # If user is active and has complete set of credentials, make a fill form task for them
            if user.active and user.login is not None and user.password is not None:
                task = await self.TaskImpl.find_one({
                    "kind": documents.TaskType.FILL_FORM.value,
                    "owner": user,
                })
                if task is None:
                    logger.info(
                        f"Creating new fill form task for user {user.pk}")
                    # Calculate next run time
                    # This time will always be in the next day, so check if it's possible to do it today
                    run_at = tasks.next_run_time(tasks.FILL_FORM_RUN_TIME)
                    if (run_at - datetime.timedelta(days=1)).replace(
                            tzinfo=None) >= datetime.datetime.utcnow():
                        run_at -= datetime.timedelta(days=1)
                    task = await self._scheduler.create_task(
                        kind=documents.TaskType.FILL_FORM,
                        run_at=run_at,
                        owner=user)
                    # Reschedule the check day task as well
                    # The task might not exist if this is the first user
                    await self._reschedule_check_day()
            # If active is set to false for this user, remove their fill form task
            elif not active:
                task = await self.TaskImpl.find_one({
                    "kind": documents.TaskType.FILL_FORM.value,
                    "owner": user,
                })
                if task is not None:
                    logger.info(f"Deleting fill form task for user {user.pk}")
                    await task.remove()
                    self._scheduler.update()
        except ValidationError as e:
            raise LockboxDBError(f"Invalid field: {e}",
                                 LockboxDBError.INVALID_FIELD) from e

    async def get_user(self, token: str) -> typing.Dict[str, typing.Any]:
        """
        Get user data as a formatted dict.
        """
        user = await self.UserImpl.find_one({"token": token})
        if user is None:
            raise LockboxDBError("Bad token", LockboxDBError.BAD_TOKEN)
        return user.dump()

    async def delete_user(self, token: str) -> None:
        """
        Delete a user by token.
        """
        user = await self.UserImpl.find_one({"token": token})
        if user is None:
            raise LockboxDBError("Bad token", LockboxDBError.BAD_TOKEN)
        # Delete screenshots
        if user.last_fill_form_result is not None:
            if user.last_fill_form_result.form_screenshot_id is not None:
                try:
                    await self._shared_gridfs.delete(
                        user.last_fill_form_result.form_screenshot_id)
                except gridfs.NoFile:
                    logger.warning(
                        f"Fill form: Failed to delete previous result form screenshot for user {user.pk}: No file"
                    )
            if user.last_fill_form_result.confirmation_screenshot_id is not None:
                try:
                    await self._shared_gridfs.delete(
                        user.last_fill_form_result.confirmation_screenshot_id)
                except gridfs.NoFile:
                    logger.warning(
                        f"Fill form: Failed to delete previous result conformation page screenshot for user {user.pk}: No file"
                    )
        # Delete fill form task
        task = await self.TaskImpl.find_one({
            "kind": documents.TaskType.FILL_FORM.value,
            "owner": user,
        })
        if task is not None:
            logger.info(f"Deleting fill form task for user {user.pk}")
            await task.remove()
            self._scheduler.update()
        await user.remove()

    async def delete_user_error(self, token: str, eid: str) -> None:
        """
        Delete an error by id for a user.
        """
        try:
            result = await self.UserImpl.collection.update_one(
                {"token": token},
                {"$pull": {
                    "errors": {
                        "_id": bson.ObjectId(eid)
                    }
                }})
        except bson.errors.InvalidId as e:
            raise LockboxDBError("Bad error id") from e
        if result.matched_count == 0:
            raise LockboxDBError("Bad token", LockboxDBError.BAD_TOKEN)
        if result.modified_count == 0:
            raise LockboxDBError("Bad error id")

    async def update_user_courses(self, token: str) -> None:
        """
        Refresh the detected courses for a user.
        """
        user = await self.UserImpl.find_one({"token": token})
        if user is None:
            raise LockboxDBError("Bad token", LockboxDBError.BAD_TOKEN)
        if user.login is None or user.password is None:
            raise LockboxDBError("Cannot update courses: Missing credentials",
                                 LockboxDBError.STATE_CONFLICT)
        # Make sure the password is valid
        try:
            self.fernet.decrypt(user.password).decode("utf-8")
        except InvalidToken as e:
            logger.critical(f"User {user.pk}'s password cannot be decrypted")
            raise LockboxDBError(
                "Internal server error: Cannot decrypt password",
                LockboxDBError.INTERNAL_ERROR) from e
        await self._scheduler.create_task(
            kind=documents.TaskType.POPULATE_COURSES, owner=user)

    async def update_all_courses(self) -> None:
        """
        Refresh the detected courses for ALL users.
        """
        batch_size = 3
        if os.environ.get("LOCKBOX_UPDATE_COURSES_BATCH_SIZE"):
            try:
                b = int(os.environ["LOCKBOX_UPDATE_COURSES_BATCH_SIZE"])
                if b < 1:
                    raise ValueError("Batch size cannot be less than 1")
                batch_size = b
            except ValueError as e:
                logger.error(
                    f"Update all courses: Invalid batch size specified by env var (defaulted to {batch_size}): {e}"
                )
        interval = 60
        if os.environ.get("LOCKBOX_UPDATE_COURSES_INTERVAL"):
            try:
                i = int(os.environ["LOCKBOX_UPDATE_COURSES_INTERVAL"])
                if i < 0:
                    raise ValueError("Interval cannot be less than 0")
                interval = i
            except ValueError as e:
                logger.error(
                    f"Update all courses: Invalid interval specified by env var (defaulted to {interval}s): {e}"
                )
        run_at = datetime.datetime.utcnow()
        batch = 0
        async for user in self.UserImpl.find({
                "login": {
                    "$ne": None
                },
                "password": {
                    "$ne": None
                }
        }):
            await self._scheduler.create_task(
                documents.TaskType.POPULATE_COURSES, run_at, user)
            batch += 1
            if batch >= batch_size:
                batch = 0
                run_at += datetime.timedelta(seconds=interval)

    async def get_form_geometry(self, token: str, url: str,
                                grab_screenshot: bool) -> dict:
        """
        Get the form geometry for a given form URL.
        """
        user = await self.UserImpl.find_one({"token": token})
        if user is None:
            raise LockboxDBError("Bad token", LockboxDBError.BAD_TOKEN)
        if user.login is None or user.password is None:
            raise LockboxDBError("Cannot sign into form: Missing credentials",
                                 LockboxDBError.STATE_CONFLICT)
        geom = await self.CachedFormGeometryImpl.find_one({"url": url})
        # Check if screenshot requirement is satisfied
        if geom is not None and grab_screenshot:
            screenshot_valid = False
            # If screenshot ID exists, check the GridFS bucket to make sure it's actually valid
            # the screenshot data may have been deleted by fenetre
            if geom.screenshot_file_id is not None:
                async for _ in self._shared_gridfs.find(
                    {"_id": geom.screenshot_file_id}):
                    screenshot_valid = True
                    break
        else:
            screenshot_valid = True
        # If this form was never requested before,
        # or the screenshot requirement is not satisfied AND the operation is not already pending
        if geom is None or (not screenshot_valid
                            and geom.geometry is not None):
            # If this is a re-run, clear the old result
            if geom is not None:
                await geom.remove()
            # Re-make the geometry
            try:
                geom = self.CachedFormGeometryImpl(
                    url=url,
                    requested_by=token,
                    geometry=None,
                    grab_screenshot=grab_screenshot)
            except ValidationError as e:
                raise LockboxDBError(f"Invalid field: {e}",
                                     LockboxDBError.INVALID_FIELD) from e
            await geom.commit()
            # Create tasks to get form geometry and clean up
            await self._scheduler.create_task(
                documents.TaskType.GET_FORM_GEOMETRY,
                owner=user,
                argument=str(geom.pk))
            await self._scheduler.create_task(
                documents.TaskType.REMOVE_OLD_FORM_GEOMETRY,
                datetime.datetime.utcnow() + datetime.timedelta(minutes=15),
                owner=user,
                argument=str(geom.pk))
            return {
                "geometry": None,
                "auth_required": None,
                "screenshot_id": None
            }
        # Result pending
        if geom.geometry is None and geom.response_status is None:
            return {
                "geometry": None,
                "auth_required": None,
                "screenshot_id": None
            }
        # Result exists
        if geom.response_status is None:
            return {
                "geometry": [e.dump() for e in geom.geometry],
                "auth_required": geom.auth_required,
                "screenshot_id": str(geom.screenshot_file_id)
            }
        return {
            "geometry": [e.dump() for e in geom.geometry],
            "screenshot_id": str(geom.screenshot_file_id),
            "auth_required": geom.auth_required,
            "error": geom.error,
            "status": geom.response_status
        }

    async def get_tasks(self) -> typing.List[dict]:
        """
        Get a list of serialized tasks.
        """
        return [
            task.dump()
            async for task in self.TaskImpl.find().sort("next_run_at", 1).sort(
                "retry_count", -1).sort("is_running", -1)
        ]

    async def find_form_test_context(self, oid: str):
        return await self.FormFillingTestImpl.find_one(
            {"_id": bson.ObjectId(oid)})

    async def start_form_test(self, oid: str, token: str):
        """
        Start filling in a test form
        """

        user = await self.UserImpl.find_one({"token": token})
        if user is None:
            raise LockboxDBError("Bad token", LockboxDBError.BAD_TOKEN)
        await self._scheduler.create_task(
            kind=documents.TaskType.TEST_FILL_FORM, owner=user, argument=oid)
        await self._scheduler.create_task(
            run_at=datetime.datetime.utcnow() + datetime.timedelta(hours=6),
            kind=documents.TaskType.REMOVE_OLD_TEST_RESULTS,
            argument=oid)
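
LockboxDB builds its document classes by registering them against umongo MotorAsyncIOInstance objects. A hedged, self-contained sketch of that pattern with illustrative fields; the real schemas live in nffu's documents module, and the connection string is a placeholder.

from motor.motor_asyncio import AsyncIOMotorClient
from umongo import Document, fields
from umongo.frameworks import MotorAsyncIOInstance

db = AsyncIOMotorClient("mongodb://localhost:27017")["lockbox"]
instance = MotorAsyncIOInstance(db)


@instance.register
class ExampleUser(Document):
    # Illustrative fields only; not the actual lockbox schema.
    token = fields.StrField(required=True, unique=True)
    grade = fields.IntField(allow_none=True)


async def demo():
    await ExampleUser.ensure_indexes()
    user = ExampleUser(token="abc123")
    await user.commit()
    return await ExampleUser.find_one({"token": "abc123"})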