class DBMiddleware():
    def __init__(self, config):
        '''
            Stores the configuration, opens the database connector, caches the
            project settings and creates the SQL string builder and the
            annotation parser.
        '''
        self.config = config
        self.dbConnector = Database(config)

        # NOTE: _fetchProjectSettings() calls getClassDefinitions(), which reads
        # self.config and self.dbConnector; both must be assigned before this call.
        self._fetchProjectSettings()
        self.sqlBuilder = SQLStringBuilder(config)
        self.annoParser = AnnotationParser(config)

    def _fetchProjectSettings(self):
        # AI controller URI
        aiControllerURI = self.config.getProperty('Server', 'aiController_uri')
        if aiControllerURI is None or aiControllerURI.strip() == '':
            # no AI backend configured
            aiControllerURI = None

        # LabelUI drawing styles
        with open(
                self.config.getProperty(
                    'LabelUI',
                    'styles_file',
                    type=str,
                    fallback='modules/LabelUI/static/json/styles.json'),
                'r') as f:
            styles = json.load(f)

        # Image backdrops for index screen
        with open(
                self.config.getProperty(
                    'Project',
                    'backdrops_file',
                    type=str,
                    fallback='modules/LabelUI/static/json/backdrops.json'),
                'r') as f:
            backdrops = json.load(f)

        # Welcome message for UI tutorial
        with open(
                self.config.getProperty(
                    'Project',
                    'welcome_message_file',
                    type=str,
                    fallback=
                    'modules/LabelUI/static/templates/welcome_message.html'),
                'r') as f:
            welcomeMessage = f.readlines()

        self.projectSettings = {
            'projectName':
            self.config.getProperty('Project', 'projectName'),
            'projectDescription':
            self.config.getProperty('Project', 'projectDescription'),
            'indexURI':
            self.config.getProperty('Server',
                                    'index_uri',
                                    type=str,
                                    fallback='/'),
            'dataServerURI':
            self.config.getProperty('Server', 'dataServer_uri'),
            'aiControllerURI':
            aiControllerURI,
            'dataType':
            self.config.getProperty('Project', 'dataType', fallback='images'),
            'classes':
            self.getClassDefinitions(),
            'enableEmptyClass':
            self.config.getProperty('Project',
                                    'enableEmptyClass',
                                    fallback='no'),
            'annotationType':
            self.config.getProperty('Project', 'annotationType'),
            'predictionType':
            self.config.getProperty('Project', 'predictionType'),
            'showPredictions':
            self.config.getProperty('LabelUI',
                                    'showPredictions',
                                    fallback='yes'),
            'showPredictions_minConf':
            self.config.getProperty('LabelUI',
                                    'showPredictions_minConf',
                                    type=float,
                                    fallback=0.5),
            'carryOverPredictions':
            self.config.getProperty('LabelUI',
                                    'carryOverPredictions',
                                    fallback='no'),
            'carryOverRule':
            self.config.getProperty('LabelUI',
                                    'carryOverRule',
                                    fallback='maxConfidence'),
            'carryOverPredictions_minConf':
            self.config.getProperty('LabelUI',
                                    'carryOverPredictions_minConf',
                                    type=float,
                                    fallback=0.75),
            'defaultBoxSize_w':
            self.config.getProperty('LabelUI',
                                    'defaultBoxSize_w',
                                    type=int,
                                    fallback=10),
            'defaultBoxSize_h':
            self.config.getProperty('LabelUI',
                                    'defaultBoxSize_h',
                                    type=int,
                                    fallback=10),
            'minBoxSize_w':
            self.config.getProperty('Project',
                                    'box_minWidth',
                                    type=int,
                                    fallback=1),
            'minBoxSize_h':
            self.config.getProperty('Project',
                                    'box_minHeight',
                                    type=int,
                                    fallback=1),
            'numImagesPerBatch':
            self.config.getProperty('LabelUI',
                                    'numImagesPerBatch',
                                    type=int,
                                    fallback=1),
            'minImageWidth':
            self.config.getProperty('LabelUI',
                                    'minImageWidth',
                                    type=int,
                                    fallback=300),
            'numImageColumns_max':
            self.config.getProperty('LabelUI',
                                    'numImageColumns_max',
                                    type=int,
                                    fallback=1),
            'defaultImage_w':
            self.config.getProperty('LabelUI',
                                    'defaultImage_w',
                                    type=int,
                                    fallback=800),
            'defaultImage_h':
            self.config.getProperty('LabelUI',
                                    'defaultImage_h',
                                    type=int,
                                    fallback=600),
            'styles':
            styles['styles'],
            'backdrops':
            backdrops,
            'welcomeMessage':
            welcomeMessage,
            'demoMode':
            self.config.getProperty('Project',
                                    'demoMode',
                                    type=bool,
                                    fallback=False)
        }

    def _assemble_annotations(self, cursor):
        response = {}
        while True:
            b = cursor.fetchone()
            if b is None:
                break

            imgID = str(b['image'])
            if not imgID in response:
                response[imgID] = {
                    'fileName': b['filename'],
                    'predictions': {},
                    'annotations': {},
                    'last_checked': None
                }
            viewcount = b['viewcount']
            if viewcount is not None:
                response[imgID]['viewcount'] = viewcount
            last_checked = b['last_checked']
            if last_checked is not None:
                if response[imgID]['last_checked'] is None:
                    response[imgID]['last_checked'] = last_checked
                else:
                    response[imgID]['last_checked'] = max(
                        response[imgID]['last_checked'], last_checked)

            # parse annotations and predictions
            entryID = str(b['id'])
            if b['ctype'] is not None:
                colnames = self.sqlBuilder.getColnames(b['ctype'])
                entry = {}
                for c in colnames:
                    value = b[c]
                    if isinstance(value, datetime):
                        value = value.timestamp()
                    elif isinstance(value, UUID):
                        value = str(value)
                    entry[c] = value

                if b['ctype'] == 'annotation':
                    response[imgID]['annotations'][entryID] = entry
                elif b['ctype'] == 'prediction':
                    response[imgID]['predictions'][entryID] = entry

        return response

    def getProjectSettings(self):
        '''
            Queries the database for general project-specific metadata, such as:
            - Classes: names, indices, default colors
            - Annotation type: one of {class labels, positions, bboxes}
        '''
        return self.projectSettings

    def getProjectInfo(self):
        '''
            Returns safe, shareable information about the project.
        '''
        return {
            'projectName':
            self.projectSettings['projectName'],
            'projectDescription':
            self.projectSettings['projectDescription'],
            'demoMode':
            self.config.getProperty('Project',
                                    'demoMode',
                                    type=bool,
                                    fallback=False),
            'backdrops':
            self.projectSettings['backdrops']['images']
        }

    def getClassDefinitions(self):
        '''
            Returns a dictionary with entries for all classes in the project.
        '''
        classdef = {
            'entries': {
                'default': {}  # default group for ungrouped label classes
            }
        }
        schema = self.config.getProperty('Database', 'schema')

        # query data
        sql = '''
            SELECT 'group' AS type, id, NULL as idx, name, color, parent, NULL AS keystroke FROM {schema}.labelclassgroup
            UNION ALL
            SELECT 'class' AS type, id, idx, name, color, labelclassgroup, keystroke FROM {schema}.labelclass;
        '''.format(schema=schema)
        classData = self.dbConnector.execute(sql, None, 'all')

        # assemble entries first
        allEntries = {}
        numClasses = 0
        for cl in classData:
            id = str(cl['id'])
            entry = {
                'id': id,
                'name': cl['name'],
                'color': cl['color'],
                'parent':
                str(cl['parent']) if cl['parent'] is not None else None,
            }
            if cl['type'] == 'group':
                entry['entries'] = {}
            else:
                entry['index'] = cl['idx']
                entry['keystroke'] = cl['keystroke']
                numClasses += 1
            allEntries[id] = entry

        # transform into tree
        def _find_parent(tree, parentID):
            if parentID is None:
                return tree['entries']['default']
            elif 'id' in tree and tree['id'] == parentID:
                return tree
            elif 'entries' in tree:
                for ek in tree['entries'].keys():
                    rv = _find_parent(tree['entries'][ek], parentID)
                    if rv is not None:
                        return rv
                return None
            else:
                return None

        allEntries['default'] = {'name': '(other)', 'entries': {}}
        allEntries = {'entries': allEntries}
        for key in list(allEntries['entries'].keys()):
            if key == 'default':
                continue
            if key in allEntries['entries']:
                entry = allEntries['entries'][key]
                parentID = entry['parent']
                del entry['parent']

                if 'entries' in entry and parentID is None:
                    # group, but no parent: append to root directly
                    allEntries['entries'][key] = entry

                else:
                    # move item
                    parent = _find_parent(allEntries, parentID)
                    parent['entries'][key] = entry
                    del allEntries['entries'][key]

        classdef = allEntries
        classdef['numClasses'] = numClasses
        return classdef

    def getBatch_fixed(self, username, data):
        '''
            Returns entries from the database based on the list of data entry identifiers specified.
        '''
        # query
        sql = self.sqlBuilder.getFixedImagesQueryString(
            self.projectSettings['demoMode'])

        # parse results
        queryVals = (
            tuple(UUID(d) for d in data),
            username,
            username,
        )
        if self.projectSettings['demoMode']:
            queryVals = (tuple(UUID(d) for d in data), )

        with self.dbConnector.execute_cursor(sql, queryVals) as cursor:
            try:
                response = self._assemble_annotations(cursor)
                # self.dbConnector.conn.commit()
            except Exception as e:
                print(e)
                # self.dbConnector.conn.rollback()
            finally:
                pass
                # cursor.close()
        return {'entries': response}

    def getBatch_auto(self,
                      username,
                      order='unlabeled',
                      subset='default',
                      limit=None):
        '''
            TODO: description
        '''
        # query
        sql = self.sqlBuilder.getNextBatchQueryString(
            order, subset, self.projectSettings['demoMode'])

        # limit (TODO: make 128 a hyperparameter)
        if limit is None:
            limit = 128
        else:
            limit = min(int(limit), 128)

        # parse results
        queryVals = (
            username,
            limit,
            username,
        )
        if self.projectSettings['demoMode']:
            queryVals = (limit, )

        with self.dbConnector.execute_cursor(sql, queryVals) as cursor:
            response = self._assemble_annotations(cursor)

        return {'entries': response}

    def getBatch_timeRange(self,
                           minTimestamp,
                           maxTimestamp,
                           userList,
                           skipEmptyImages=False,
                           limit=None):
        '''
            Returns images that have been annotated within the given time range and/or
            by the given user(s). All arguments are optional.
            Useful for reviewing existing annotations.
        '''
        # query string
        sql = self.sqlBuilder.getDateQueryString(minTimestamp, maxTimestamp,
                                                 userList, skipEmptyImages)

        # check validity and provide arguments
        queryVals = []
        if userList is not None:
            queryVals.append(tuple(userList))
        if minTimestamp is not None:
            queryVals.append(minTimestamp)
        if maxTimestamp is not None:
            queryVals.append(maxTimestamp)
        if skipEmptyImages and userList is not None:
            queryVals.append(tuple(userList))

        # limit (TODO: make 128 a hyperparameter)
        if limit is None:
            limit = 128
        else:
            limit = min(int(limit), 128)
        queryVals.append(limit)

        if userList is not None:
            queryVals.append(tuple(userList))

        # query and parse results
        with self.dbConnector.execute_cursor(sql, tuple(queryVals)) as cursor:
            try:
                response = self._assemble_annotations(cursor)
                # self.dbConnector.conn.commit()
            except Exception as e:
                print(e)
                # self.dbConnector.conn.rollback()
            finally:
                pass
                # cursor.close()
        return {'entries': response}

    def get_timeRange(self, userList, skipEmptyImages=False):
        '''
            Returns two timestamps denoting the temporal limits within which
            images have been viewed by the users provided in the userList.
            Arguments:
            - userList: string (single user name) or list of strings (multiple).
                        Can also be None; in this case all annotations will be
                        checked.
            - skipEmptyImages: if True, only images that contain at least one
                               annotation will be considered.
        '''
        # query string
        sql = self.sqlBuilder.getTimeRangeQueryString(userList,
                                                      skipEmptyImages)

        arguments = (None if userList is None else tuple(userList))
        result = self.dbConnector.execute(sql, (arguments, ), numReturn=1)

        if result is not None and len(result):
            return {
                'minTimestamp': result[0]['mintimestamp'],
                'maxTimestamp': result[0]['maxtimestamp'],
            }
        else:
            return {'error': 'no annotations made'}

    def submitAnnotations(self, username, submissions):
        '''
            Sends user-provided annotations to the database.
        '''
        if self.projectSettings['demoMode']:
            return 0

        # assemble values
        colnames = getattr(QueryStrings_annotation,
                           self.projectSettings['annotationType']).value
        values_insert = []
        values_update = []

        meta = (None if not 'meta' in submissions else json.dumps(
            submissions['meta']))

        # for deletion: remove all annotations whose image ID matches but whose annotation ID is not among the submitted ones
        ids = []

        viewcountValues = []
        for imageKey in submissions['entries']:
            entry = submissions['entries'][imageKey]

            try:
                lastChecked = entry['timeCreated']
                lastTimeRequired = entry['timeRequired']
                if lastTimeRequired is None: lastTimeRequired = 0
            except:
                lastChecked = datetime.now(tz=pytz.utc)
                lastTimeRequired = 0

            if 'annotations' in entry and len(entry['annotations']):
                for annotation in entry['annotations']:
                    # assemble annotation values
                    annotationTokens = self.annoParser.parseAnnotation(
                        annotation)
                    annoValues = []
                    for cname in colnames:
                        if cname == 'id':
                            if cname in annotationTokens:
                                # cast and only append id if the annotation is an existing one
                                annoValues.append(UUID(
                                    annotationTokens[cname]))
                                ids.append(UUID(annotationTokens[cname]))
                        elif cname == 'image':
                            annoValues.append(UUID(imageKey))
                        elif cname == 'label' and annotationTokens[
                                cname] is not None:
                            annoValues.append(UUID(annotationTokens[cname]))
                        elif cname == 'timeCreated':
                            try:
                                annoValues.append(
                                    dateutil.parser.parse(
                                        annotationTokens[cname]))
                            except:
                                annoValues.append(datetime.now(tz=pytz.utc))
                        elif cname == 'timeRequired':
                            timeReq = annotationTokens[cname]
                            if timeReq is None: timeReq = 0
                            annoValues.append(timeReq)
                        elif cname == 'username':
                            annoValues.append(username)
                        elif cname in annotationTokens:
                            annoValues.append(annotationTokens[cname])
                        elif cname == 'unsure':
                            if 'unsure' in annotationTokens and annotationTokens[
                                    'unsure'] is not None:
                                annoValues.append(annotationTokens[cname])
                            else:
                                annoValues.append(False)
                        elif cname == 'meta':
                            annoValues.append(meta)
                        else:
                            annoValues.append(None)
                    if 'id' in annotationTokens:
                        # existing annotation; update
                        values_update.append(tuple(annoValues))
                    else:
                        # new annotation
                        values_insert.append(tuple(annoValues))

            viewcountValues.append(
                (username, imageKey, 1, lastChecked, lastTimeRequired, meta))

        schema = self.config.getProperty('Database', 'schema')

        # delete all annotations that are not in submitted batch
        imageKeys = list(UUID(k) for k in submissions['entries'])
        if len(imageKeys):
            if len(ids):
                sql = '''
                    DELETE FROM {schema}.annotation WHERE username = %s AND id IN (
                        SELECT idQuery.id FROM (
                            SELECT * FROM {schema}.annotation WHERE id NOT IN %s
                        ) AS idQuery
                        JOIN (
                            SELECT * FROM {schema}.annotation WHERE image IN %s
                        ) AS imageQuery ON idQuery.id = imageQuery.id);
                '''.format(schema=schema)
                self.dbConnector.execute(sql, (
                    username,
                    tuple(ids),
                    tuple(imageKeys),
                ))
            else:
                # no annotations submitted; delete all annotations submitted before
                sql = '''
                    DELETE FROM {schema}.annotation WHERE username = %s AND image IN %s;
                '''.format(schema=schema)
                self.dbConnector.execute(sql, (
                    username,
                    tuple(imageKeys),
                ))

        # insert new annotations
        if len(values_insert):
            sql = '''
                INSERT INTO {}.annotation ({})
                VALUES %s ;
            '''.format(
                schema,
                ', '.join(colnames[1:])  # skip 'id' column
            )
            self.dbConnector.insert(sql, values_insert)

        # update existing annotations
        if len(values_update):
            updateCols = ''
            for col in colnames:
                if col == 'label':
                    updateCols += '{col} = UUID(e.{col}),'.format(col=col)
                elif col == 'timeRequired':
                    # we sum the required times together
                    updateCols += '{col} = COALESCE(a.{col},0) + COALESCE(e.{col},0),'.format(
                        col=col)
                else:
                    updateCols += '{col} = e.{col},'.format(col=col)

            sql = '''
                UPDATE {schema}.annotation AS a
                SET {updateCols}
                FROM (VALUES %s) AS e({colnames})
                WHERE e.id = a.id;
            '''.format(schema=schema,
                       updateCols=updateCols.strip(','),
                       colnames=', '.join(colnames))
            self.dbConnector.insert(sql, values_update)

        # viewcount table
        sql = '''
            INSERT INTO {}.image_user (username, image, viewcount, last_checked, last_time_required, meta)
            VALUES %s 
            ON CONFLICT (username, image) DO UPDATE SET viewcount = image_user.viewcount + 1, last_checked = EXCLUDED.last_checked, last_time_required = EXCLUDED.last_time_required, meta = EXCLUDED.meta;
        '''.format(schema)

        self.dbConnector.insert(sql, viewcountValues)

        return 0
# Example no. 2
class DBMiddleware():
    def __init__(self, config):
        '''
            Stores the configuration, opens the database connector, prepares
            the cache for immutable project settings, loads the global settings
            and default styles, and creates the SQL string builder and the
            annotation parser.
        '''
        self.config = config
        self.dbConnector = Database(config)

        self.project_immutables = {
        }  # project settings that cannot be changed (project shorthand -> {settings})

        # reads from self.config, which therefore has to be assigned before this call
        self._fetchProjectSettings()
        self.sqlBuilder = SQLStringBuilder()
        self.annoParser = AnnotationParser()

    def _fetchProjectSettings(self):
        # AI controller URI
        aiControllerURI = self.config.getProperty('Server', 'aiController_uri')
        if aiControllerURI is None or aiControllerURI.strip() == '':
            # no AI backend configured
            aiControllerURI = None

        # global, project-independent settings
        self.globalSettings = {
            'indexURI':
            self.config.getProperty('Server',
                                    'index_uri',
                                    type=str,
                                    fallback='/'),
            'dataServerURI':
            self.config.getProperty('Server', 'dataServer_uri'),
            'aiControllerURI':
            aiControllerURI
        }

        # default styles
        try:
            # check if custom default styles are provided
            self.defaultStyles = json.load(
                open('config/default_ui_settings.json', 'r'))
        except:
            # resort to built-in styles
            self.defaultStyles = json.load(
                open(
                    'modules/ProjectAdministration/static/json/default_ui_settings.json',
                    'r'))

    def _assemble_annotations(self, project, cursor, hideGoldenQuestionInfo):
        response = {}
        while True:
            b = cursor.fetchone()
            if b is None:
                break

            imgID = str(b['image'])
            if not imgID in response:
                response[imgID] = {
                    'fileName': b['filename'],
                    'predictions': {},
                    'annotations': {},
                    'last_checked': None
                }
            viewcount = b['viewcount']
            if viewcount is not None:
                response[imgID]['viewcount'] = viewcount
            last_checked = b['last_checked']
            if last_checked is not None:
                if response[imgID]['last_checked'] is None:
                    response[imgID]['last_checked'] = last_checked
                else:
                    response[imgID]['last_checked'] = max(
                        response[imgID]['last_checked'], last_checked)

            if not hideGoldenQuestionInfo:
                response[imgID]['isGoldenQuestion'] = b['isgoldenquestion']

            # parse annotations and predictions
            entryID = str(b['id'])
            if b['ctype'] is not None:
                colnames = self.sqlBuilder.getColnames(
                    self.project_immutables[project]['annotationType'],
                    self.project_immutables[project]['predictionType'],
                    b['ctype'])
                entry = {}
                for c in colnames:
                    value = b[c]
                    if isinstance(value, datetime):
                        value = value.timestamp()
                    elif isinstance(value, UUID):
                        value = str(value)
                    entry[c] = value

                if b['ctype'] == 'annotation':
                    response[imgID]['annotations'][entryID] = entry
                elif b['ctype'] == 'prediction':
                    response[imgID]['predictions'][entryID] = entry

        return response

    def _set_images_requested(self, project, imageIDs):
        '''
            Sets column "last_requested" of relation "image"
            to the current date. This is done during image
            querying to signal that an image has been requested,
            but not (yet) viewed.
        '''
        # prepare insertion values
        now = datetime.now(tz=pytz.utc)
        vals = []
        for key in imageIDs:
            vals.append(key)
        if len(vals):
            queryStr = sql.SQL('''
                UPDATE {id_img}
                SET last_requested = %s
                WHERE id IN %s;
            ''').format(id_img=sql.Identifier(project, 'image'))
            self.dbConnector.execute(queryStr, (
                now,
                tuple(vals),
            ), None)

    def _get_sample_metadata(self, metaType):
        '''
            Returns a dummy annotation or prediction for the sample
            image in the "exampleData" folder, depending on the "metaType"
            specified (i.e., labels, points, boundingBoxes, or segmentationMasks).
        '''
        if metaType == 'labels':
            return {
                'id': '00000000-0000-0000-0000-000000000000',
                'label': '00000000-0000-0000-0000-000000000000',
                'confidence': 1.0,
                'priority': 1.0,
                'viewcount': None
            }
        elif metaType == 'points' or metaType == 'boundingBoxes':
            return {
                'id': '00000000-0000-0000-0000-000000000000',
                'label': '00000000-0000-0000-0000-000000000000',
                'x': 0.542959427207637,
                'y': 0.5322069489713102,
                'width': 0.6133651551312653,
                'height': 0.7407598263401316,
                'confidence': 1.0,
                'priority': 1.0,
                'viewcount': None
            }
        elif metaType == 'segmentationMasks':
            # read segmentation mask from disk
            segmask = Image.open(
                'modules/LabelUI/static/exampleData/sample_segmentationMask.tif'
            )
            segmask, width, height = helpers.imageToBase64(segmask)
            return {
                'id': '00000000-0000-0000-0000-000000000000',
                'width': width,
                'height': height,
                'segmentationmask': segmask,
                'confidence': 1.0,
                'priority': 1.0,
                'viewcount': None
            }
        else:
            return {}

    def get_project_immutables(self, project):
        if project not in self.project_immutables:
            queryStr = 'SELECT annotationType, predictionType, demoMode FROM aide_admin.project WHERE shortname = %s;'
            result = self.dbConnector.execute(queryStr, (project, ), 1)
            if result and len(result):
                self.project_immutables[project] = {
                    'annotationType': result[0]['annotationtype'],
                    'predictionType': result[0]['predictiontype'],
                    'demoMode': helpers.checkDemoMode(project,
                                                      self.dbConnector)
                }
            else:
                return None
        return self.project_immutables[project]

    def get_dynamic_project_settings(self, project):
        queryStr = 'SELECT ui_settings FROM aide_admin.project WHERE shortname = %s;'
        result = self.dbConnector.execute(queryStr, (project, ), 1)
        result = json.loads(result[0]['ui_settings'])

        # complete styles with defaults where necessary (may be required for project that got upgraded from v1)
        result = helpers.check_args(result, self.defaultStyles)

        return result

    def getProjectSettings(self, project):
        '''
            Assembles the complete settings dictionary for a project from
            the following sources (later ones take precedence if keys
            collide):
            - the publicly shareable project info (see "getProjectInfo")
            - the label class definitions (stored under key "classes")
            - the immutable project properties
            - the dynamic (UI) project settings
            - the global settings of this instance

            If the resulting "aiControllerURI" is a non-empty string, the
            project shortname and a trailing slash are appended to it.
        '''
        # URIs always use forward slashes; os.path.join would insert
        # backslashes on Windows
        import posixpath

        # publicly available info from DB
        projSettings = self.getProjectInfo(project)

        # label classes
        projSettings['classes'] = self.getClassDefinitions(project)

        # static and dynamic project settings and properties
        projSettings = {
            **projSettings,
            **self.get_project_immutables(project),
            **self.get_dynamic_project_settings(project),
            **self.globalSettings
        }

        # append project shorthand to AIController URI
        aiControllerURI = projSettings.get('aiControllerURI')
        if aiControllerURI is not None and len(aiControllerURI):
            projSettings['aiControllerURI'] = posixpath.join(
                aiControllerURI, project) + '/'

        return projSettings

    def getProjectInfo(self, project):
        '''
            Fetches the shareable properties of a project, i.e., those that
            may be shown to users who are not members of the project.

            Returns a dict with keys "projectShortname", "projectName",
            "projectDescription", "demoMode", "interface_enabled",
            "ai_model_available" and "segmentation_ignore_unlabeled".
        '''
        infoQuery = '''
            SELECT shortname, name, description, demoMode,
            interface_enabled, archived, ai_model_enabled,
            ai_model_library, ai_alcriterion_library,
            segmentation_ignore_unlabeled
            FROM aide_admin.project
            WHERE shortname = %s
        '''
        row = self.dbConnector.execute(infoQuery, (project, ), 1)[0]

        # the AI model counts as available if it is enabled and both the
        # model and the AL criterion library are set to a non-empty value
        conditions = [row['ai_model_enabled']]
        for libraryKey in ('ai_model_library', 'ai_alcriterion_library'):
            library = row[libraryKey]
            conditions.append(library is not None and len(library))
        modelAvailable = all(conditions)

        info = {}
        info['projectShortname'] = row['shortname']
        info['projectName'] = row['name']
        info['projectDescription'] = row['description']
        info['demoMode'] = row['demomode']
        # archived projects never expose the interface
        info['interface_enabled'] = (row['interface_enabled']
                                     and not row['archived'])
        info['ai_model_available'] = modelAvailable
        info['segmentation_ignore_unlabeled'] = row[
            'segmentation_ignore_unlabeled']
        return info

    def getClassDefinitions(self, project, showHidden=False):
        '''
            Returns the label class groups and label classes of a project,
            arranged as a tree.

            Arguments:
            - project: project shortname (used as the database schema name)
            - showHidden: if True, label classes flagged as hidden are
                          included as well; otherwise they are left out.
                          Label class groups are always returned.

            Returns a dict with keys:
            - "entries": dict of all root-level groups and classes, keyed by
                         their ID (as string). Each group has a nested
                         "entries" dict with its children; each class has
                         "index" and "keystroke" instead.
            - "numClasses": number of label classes (not groups) returned.

            Entries whose parent cannot be resolved to a group are kept at
            root level.
        '''

        # query data; the filter only applies to the second SELECT (classes)
        hiddenSpec = '' if showHidden else 'WHERE hidden IS false'
        queryStr = sql.SQL('''
            SELECT 'group' AS type, id, NULL as idx, name, color, parent, NULL AS keystroke, NULL AS hidden FROM {}
            UNION ALL
            SELECT 'class' AS type, id, idx, name, color, labelclassgroup, keystroke, hidden FROM {}
            {};
            ''').format(sql.Identifier(project, 'labelclassgroup'),
                        sql.Identifier(project, 'labelclass'),
                        sql.SQL(hiddenSpec))

        classData = self.dbConnector.execute(queryStr, None, 'all')

        # assemble flat entries first, indexed by their ID
        entriesByID = {}
        parentIDs = {}
        numClasses = 0
        for cl in classData:
            entryID = str(cl['id'])
            entry = {
                'id': entryID,
                'name': cl['name'],
                'color': cl['color'],
                'hidden': cl['hidden']
            }
            if cl['type'] == 'group':
                entry['entries'] = {}
            else:
                entry['index'] = cl['idx']
                entry['keystroke'] = cl['keystroke']
                numClasses += 1
            entriesByID[entryID] = entry
            parentIDs[entryID] = (str(cl['parent'])
                                  if cl['parent'] is not None else None)

        # transform into tree: every entry is attached to its parent group,
        # which is looked up directly in the index instead of being searched
        # for recursively from the root
        rootEntries = {}
        for entryID, entry in entriesByID.items():
            parentID = parentIDs[entryID]
            parent = entriesByID.get(parentID) if parentID is not None else None
            if parent is None or 'entries' not in parent:
                # no (resolvable) parent group: append to root directly
                rootEntries[entryID] = entry
            else:
                parent['entries'][entryID] = entry

        return {'entries': rootEntries, 'numClasses': numClasses}

    def getBatch_fixed(self,
                       project,
                       username,
                       data,
                       hideGoldenQuestionInfo=True):
        '''
            Returns entries from the database based on the list of data entry
            identifiers specified.

            Arguments:
            - project: project shortname
            - username: name of the user requesting the images
            - data: iterable of image IDs (UUID strings)
            - hideGoldenQuestionInfo: forwarded to "_assemble_annotations"

            Returns a dict {'entries': ...} with the assembled annotations.
            The entries are empty if "data" is empty or if the annotations
            could not be assembled (the error is printed in that case).
            Successfully retrieved images are marked as requested.
        '''

        if not len(data):
            return {'entries': {}}

        # query
        projImmutables = self.get_project_immutables(project)
        demoMode = projImmutables['demoMode']
        queryStr = self.sqlBuilder.getFixedImagesQueryString(
            project, projImmutables['annotationType'],
            projImmutables['predictionType'], demoMode)

        # the demo mode query only takes the image IDs as argument
        imageIDs = tuple(UUID(d) for d in data)
        if demoMode:
            queryVals = (imageIDs, )
        else:
            queryVals = (
                imageIDs,
                username,
                username,
            )

        # parse results
        with self.dbConnector.execute_cursor(queryStr, queryVals) as cursor:
            try:
                response = self._assemble_annotations(project, cursor,
                                                      hideGoldenQuestionInfo)
            except Exception as e:
                # best effort: report the problem and return an empty batch
                # instead of failing later on an unassigned response
                print(e)
                return {'entries': {}}

        # mark images as requested
        self._set_images_requested(project, response)

        return {'entries': response}

    def getBatch_auto(self,
                      project,
                      username,
                      order='unlabeled',
                      subset='default',
                      limit=None,
                      hideGoldenQuestionInfo=True):
        '''
            Returns the next batch of images for a user, using the query
            provided by the SQL builder for the given "order" and "subset".
            At most 128 images are returned; "limit" can reduce that number
            further. All returned images are marked as requested.
        '''
        immutables = self.get_project_immutables(project)
        batchQuery = self.sqlBuilder.getNextBatchQueryString(
            project, immutables['annotationType'],
            immutables['predictionType'], order, subset,
            immutables['demoMode'])

        # limit (TODO: make 128 a hyperparameter)
        maxBatchSize = 128
        limit = maxBatchSize if limit is None else min(int(limit),
                                                       maxBatchSize)

        # the demo mode query only takes the limit as argument
        #TODO: demoMode can now change dynamically
        if immutables['demoMode']:
            batchArgs = (limit, )
        else:
            batchArgs = (username, username, limit, username)

        with self.dbConnector.execute_cursor(batchQuery,
                                             batchArgs) as cursor:
            entries = self._assemble_annotations(project, cursor,
                                                 hideGoldenQuestionInfo)

        # mark images as requested
        self._set_images_requested(project, entries)

        return {'entries': entries}

    def getBatch_timeRange(self,
                           project,
                           minTimestamp,
                           maxTimestamp,
                           userList,
                           skipEmptyImages=False,
                           limit=None,
                           goldenQuestionsOnly=False,
                           hideGoldenQuestionInfo=True):
        '''
            Returns images that have been annotated within the given time range and/or
            by the given user(s). "minTimestamp", "maxTimestamp" and "userList"
            may be None to leave the respective criterion out.
            Useful for reviewing existing annotations.

            At most 128 images are returned; "limit" can reduce that number
            further. Unlike the other batch functions, images are not marked
            as requested.

            Returns a dict {'entries': ...}; the entries are empty if the
            annotations could not be assembled (the error is printed in that
            case).
        '''
        # query string
        projImmutables = self.get_project_immutables(project)
        queryStr = self.sqlBuilder.getDateQueryString(
            project, projImmutables['annotationType'], minTimestamp,
            maxTimestamp, userList, skipEmptyImages, goldenQuestionsOnly)

        # limit (TODO: make 128 a hyperparameter)
        if limit is None:
            limit = 128
        else:
            limit = min(int(limit), 128)

        # provide arguments; their order has to match the placeholders of
        # the query string assembled above
        users = None if userList is None else tuple(userList)
        queryVals = []
        if users is not None:
            queryVals.append(users)
        if minTimestamp is not None:
            queryVals.append(minTimestamp)
        if maxTimestamp is not None:
            queryVals.append(maxTimestamp)
        if skipEmptyImages and users is not None:
            queryVals.append(users)
        queryVals.append(limit)
        if users is not None:
            queryVals.append(users)

        # query and parse results
        with self.dbConnector.execute_cursor(queryStr,
                                             tuple(queryVals)) as cursor:
            try:
                response = self._assemble_annotations(project, cursor,
                                                      hideGoldenQuestionInfo)
            except Exception as e:
                # best effort: report the problem and return an empty batch
                # instead of failing later on an unassigned response
                print(e)
                return {'entries': {}}

        return {'entries': response}

    def get_timeRange(self,
                      project,
                      userList,
                      skipEmptyImages=False,
                      goldenQuestionsOnly=False):
        '''
            Returns two timestamps denoting the temporal limits within which
            images have been viewed by the users provided in the userList.
            Arguments:
            - userList: string (single user name) or list of strings (multiple).
                        Can also be None; in this case all annotations will be
                        checked.
            - skipEmptyImages: if True, only images that contain at least one
                               annotation will be considered.
            - goldenQuestionsOnly: if True, only images flagged as golden questions
                                   will be shown.

            Returns a dict with keys "minTimestamp" and "maxTimestamp", or
            {'error': 'no annotations made'} if the query yields no result.
        '''
        # query string
        queryStr = self.sqlBuilder.getTimeRangeQueryString(
            project, userList, skipEmptyImages, goldenQuestionsOnly)

        if userList is None:
            arguments = None
        elif isinstance(userList, str):
            # a single user name must not be split into its characters
            arguments = (userList, )
        else:
            arguments = tuple(userList)
        result = self.dbConnector.execute(queryStr, (arguments, ), numReturn=1)

        if result is not None and len(result):
            return {
                'minTimestamp': result[0]['mintimestamp'],
                'maxTimestamp': result[0]['maxtimestamp'],
            }
        return {'error': 'no annotations made'}

    def get_sampleData(self, project):
        '''
            Returns a sample image from the project, with annotations
            (from one of the admins) and predictions.
            If no image, no annotations, and/or no predictions are
            available, or if they cannot be assembled, a built-in default
            is returned instead.
        '''
        projImmutables = self.get_project_immutables(project)
        queryStr = self.sqlBuilder.getSampleDataQueryString(
            project, projImmutables['annotationType'],
            projImmutables['predictionType'])

        # query and parse results
        response = None
        with self.dbConnector.execute_cursor(queryStr, None) as cursor:
            try:
                response = self._assemble_annotations(project, cursor, True)
            except Exception:
                # deliberately ignored: the built-in sample below is used
                # whenever no project data can be assembled. Only regular
                # errors are caught, so that e.g. KeyboardInterrupt still
                # propagates.
                pass

        if response is None or not len(response):
            # no valid data found for project; fall back to sample data
            sampleID = '00000000-0000-0000-0000-000000000000'
            response = {
                sampleID: {
                    'fileName':
                    '/static/interface/exampleData/sample_image.jpg',
                    'viewcount': 1,
                    'annotations': {
                        sampleID:
                        self._get_sample_metadata(
                            projImmutables['annotationType'])
                    },
                    'predictions': {
                        sampleID:
                        self._get_sample_metadata(
                            projImmutables['predictionType'])
                    },
                    'last_checked': None,
                    'isGoldenQuestion': True
                }
            }
        return response

    def submitAnnotations(self, project, username, submissions):
        '''
            Sends user-provided annotations to the database.

            Arguments:
            - project: project shortname (used as the database schema name)
            - username: name of the user the annotations are stored for
            - submissions: dict with key "entries", which maps image IDs
                           (UUID strings) to dicts that may contain
                           "annotations", "timeCreated", "timeRequired" and
                           "numInteractions". An optional key "meta" is
                           stored as JSON string with every annotation and
                           view count record.

            For all submitted images, the user's annotations that are not
            part of the submission are deleted, annotations without ID are
            inserted, annotations with ID are updated, and the user's view
            count record of each image is inserted or updated.

            Returns 1 (without writing anything) if the project is in demo
            mode, and 0 otherwise.
        '''
        projImmutables = self.get_project_immutables(project)
        if projImmutables['demoMode']:
            return 1

        # assemble values
        # column names for the project's annotation type; the value tuples
        # assembled below follow the order of these names
        colnames = getattr(QueryStrings_annotation,
                           projImmutables['annotationType']).value
        values_insert = []
        values_update = []

        meta = (None if not 'meta' in submissions else json.dumps(
            submissions['meta']))

        # for deletion: remove all annotations whose image ID matches but whose annotation ID is not among the submitted ones
        ids = []

        viewcountValues = []
        for imageKey in submissions['entries']:
            entry = submissions['entries'][imageKey]

            # time stamp and required time of the entry; fall back to the
            # current time and zero if either key is missing
            try:
                lastChecked = entry['timeCreated']
                lastTimeRequired = entry['timeRequired']
                if lastTimeRequired is None: lastTimeRequired = 0
            except:
                lastChecked = datetime.now(tz=pytz.utc)
                lastTimeRequired = 0

            try:
                numInteractions = int(entry['numInteractions'])
            except:
                numInteractions = 0

            if 'annotations' in entry and len(entry['annotations']):
                for annotation in entry['annotations']:
                    # assemble annotation values
                    annotationTokens = self.annoParser.parseAnnotation(
                        annotation)
                    annoValues = []
                    for cname in colnames:
                        if cname == 'id':
                            if cname in annotationTokens:
                                # cast and only append id if the annotation is an existing one
                                annoValues.append(UUID(
                                    annotationTokens[cname]))
                                ids.append(UUID(annotationTokens[cname]))
                        elif cname == 'image':
                            annoValues.append(UUID(imageKey))
                        elif cname == 'label' and annotationTokens[
                                cname] is not None:
                            annoValues.append(UUID(annotationTokens[cname]))
                        elif cname == 'timeCreated':
                            # unparseable or missing time stamps are replaced
                            # with the current time
                            try:
                                annoValues.append(
                                    dateutil.parser.parse(
                                        annotationTokens[cname]))
                            except:
                                annoValues.append(datetime.now(tz=pytz.utc))
                        elif cname == 'timeRequired':
                            timeReq = annotationTokens[cname]
                            if timeReq is None: timeReq = 0
                            annoValues.append(timeReq)
                        elif cname == 'username':
                            annoValues.append(username)
                        elif cname in annotationTokens:
                            annoValues.append(annotationTokens[cname])
                        elif cname == 'unsure':
                            # NOTE(review): this branch is only reached if
                            # 'unsure' is absent from annotationTokens (the
                            # branch above handles it otherwise), so the
                            # check below is always False here and an
                            # 'unsure' value of None is stored as None by
                            # the branch above instead of False. Presumably
                            # this branch was meant to come first -- confirm.
                            if 'unsure' in annotationTokens and annotationTokens[
                                    'unsure'] is not None:
                                annoValues.append(annotationTokens[cname])
                            else:
                                annoValues.append(False)
                        elif cname == 'meta':
                            annoValues.append(meta)
                        else:
                            annoValues.append(None)
                    if 'id' in annotationTokens:
                        # existing annotation; update
                        values_update.append(tuple(annoValues))
                    else:
                        # new annotation
                        values_insert.append(tuple(annoValues))

            # one view count record per submitted image, also for images
            # without annotations
            viewcountValues.append(
                (username, imageKey, 1, lastChecked, lastChecked,
                 lastTimeRequired, lastTimeRequired, numInteractions, meta))

        # delete all annotations that are not in submitted batch
        imageKeys = list(UUID(k) for k in submissions['entries'])
        if len(imageKeys):
            if len(ids):
                # delete the user's annotations on the submitted images whose
                # ID is not among the submitted (existing) annotation IDs
                queryStr = sql.SQL('''
                    DELETE FROM {id_anno} WHERE username = %s AND id IN (
                        SELECT idQuery.id FROM (
                            SELECT * FROM {id_anno} WHERE id NOT IN %s
                        ) AS idQuery
                        JOIN (
                            SELECT * FROM {id_anno} WHERE image IN %s
                        ) AS imageQuery ON idQuery.id = imageQuery.id);
                ''').format(id_anno=sql.Identifier(project, 'annotation'))
                self.dbConnector.execute(queryStr, (
                    username,
                    tuple(ids),
                    tuple(imageKeys),
                ))
            else:
                # no annotations submitted; delete all annotations submitted before
                queryStr = sql.SQL('''
                    DELETE FROM {id_anno} WHERE username = %s AND image IN %s;
                ''').format(id_anno=sql.Identifier(project, 'annotation'))
                self.dbConnector.execute(queryStr, (
                    username,
                    tuple(imageKeys),
                ))

        # insert new annotations
        if len(values_insert):
            # NOTE(review): colnames[1:] assumes that 'id' is the first
            # column name; the value tuples of new annotations contain no
            # ID (see above) -- confirm against QueryStrings_annotation.
            queryStr = sql.SQL('''
                INSERT INTO {id_anno} ({cols})
                VALUES %s ;
            ''').format(
                id_anno=sql.Identifier(project, 'annotation'),
                cols=sql.SQL(', ').join([sql.SQL(c) for c in colnames[1:]
                                         ])  # skip 'id' column
            )
            self.dbConnector.insert(queryStr, values_insert)

        # update existing annotations
        if len(values_update):

            # one "column = value" assignment per column name
            updateCols = []
            for col in colnames:
                if col == 'label':
                    updateCols.append(sql.SQL('label = UUID(e.label)'))
                elif col == 'timeRequired':
                    # we sum the required times together
                    updateCols.append(
                        sql.SQL(
                            'timeRequired = COALESCE(a.timeRequired,0) + COALESCE(e.timeRequired,0)'
                        ))
                else:
                    updateCols.append(
                        sql.SQL('{col} = e.{col}').format(col=sql.SQL(col)))

            queryStr = sql.SQL('''
                UPDATE {id_anno} AS a
                SET {updateCols}
                FROM (VALUES %s) AS e({colnames})
                WHERE e.id = a.id
            ''').format(id_anno=sql.Identifier(project, 'annotation'),
                        updateCols=sql.SQL(', ').join(updateCols),
                        colnames=sql.SQL(', ').join(
                            [sql.SQL(c) for c in colnames]))

            self.dbConnector.insert(queryStr, values_update)

        # viewcount table
        # new (username, image) pairs are inserted; for existing ones the
        # view count is incremented, the required time and the number of
        # interactions are accumulated, and the remaining columns replaced
        queryStr = sql.SQL('''
            INSERT INTO {id_iu} (username, image, viewcount, first_checked, last_checked, last_time_required, total_time_required, num_interactions, meta)
            VALUES %s 
            ON CONFLICT (username, image) DO UPDATE SET viewcount = image_user.viewcount + 1,
                last_checked = EXCLUDED.last_checked,
                last_time_required = EXCLUDED.last_time_required,
                total_time_required = EXCLUDED.total_time_required + image_user.total_time_required,
                num_interactions = EXCLUDED.num_interactions + image_user.num_interactions,
                meta = EXCLUDED.meta;
        ''').format(id_iu=sql.Identifier(project, 'image_user'))
        self.dbConnector.insert(queryStr, viewcountValues)

        return 0

    def setGoldenQuestions(self, project, submissions):
        '''
            Updates the property "isGoldenQuestion" of images.
            "submissions" is an iterable of tuples (uuid, bool), each
            providing an image ID and the value to set for that image.

            Returns 1 (without modifying anything) if the project is in
            demo mode, and 0 otherwise.
        '''
        if self.get_project_immutables(project)['demoMode']:
            # no modifications in demo mode
            return 1

        updateQuery = sql.SQL('''
            UPDATE {id_img} AS img SET isGoldenQuestion = c.isGoldenQuestion
            FROM (VALUES %s)
            AS c (id, isGoldenQuestion)
            WHERE c.id = img.id;
        ''').format(id_img=sql.Identifier(project, 'image'))
        self.dbConnector.insert(updateQuery, submissions)
        return 0
# Example no. 3
# 0
class DataWorker:
    # NOTE(review): the class has no docstring. Judging from the methods
    # visible below it handles image administration for a project (listing
    # and uploading images) -- confirm against the rest of the file.

    # Character sequences that are prohibited in file names.
    # NOTE(review): '<' is listed twice. Given the '>' / '&gt;' pair below,
    # the second entry was presumably meant to be '&lt;' -- confirm and fix.
    FILENAMES_PROHIBITED_CHARS = (
        '<',
        '<',
        '>',
        '&gt;',
        '..',
        '/',
        '\\',
        '|',
        '?',
        '*',
        ':'  # for macOS
    )

    NUM_IMAGES_LIMIT = 4096  # maximum number of images that can be queried at once (to avoid bottlenecks)

    def __init__(self, config, passiveMode=False):
        '''
            Arguments:
            - config: configuration object; used to set up the database
                      connection and to read the "FileServer" properties
            - passiveMode: if True, calls to "aide_internal_notify" are
                           ignored
        '''
        self.config = config
        self.dbConnector = Database(config)

        # matches a trailing underscore followed by digits (e.g. "_12").
        # Raw string: "\_" is an invalid escape sequence in a regular string
        # literal; a plain underscore matches the same text.
        self.countPattern = re.compile(r'_[0-9]+$')
        self.passiveMode = passiveMode

        # directory for temporary files; defaults to the system's
        self.tempDir = self.config.getProperty('FileServer',
                                               'tempfiles_dir',
                                               type=str,
                                               fallback=tempfile.gettempdir())

    def aide_internal_notify(self, message):
        '''
            Handles messages for AIDE administrative communication.
            The only task handled here is "create_project_folders", which
            creates the folder of the project given under "projectName"
            inside the file server's static files directory.
            Does nothing in passive mode.
        '''
        if self.passiveMode or 'task' not in message:
            return
        if message['task'] != 'create_project_folders':
            return
        if 'projectName' not in message:
            return

        # set up folders for a newly created project
        staticFilesDir = self.config.getProperty('FileServer',
                                                 'staticfiles_dir')
        projectDir = os.path.join(staticFilesDir, message['projectName'])
        os.makedirs(projectDir, exist_ok=True)

    ''' Image administration functionalities '''

    def listImages(self,
                   project,
                   folder=None,
                   imageAddedRange=None,
                   lastViewedRange=None,
                   viewcountRange=None,
                   numAnnoRange=None,
                   numPredRange=None,
                   orderBy=None,
                   order='desc',
                   startFrom=None,
                   limit=None):
        '''
            Returns a list of images, with ID, filename,
            date image was added, viewcount, number of annotations,
            number of predictions, last time viewed, and golden question
            flag, for a given project.

            Arguments:
            - project: project shortname (used as the database schema name)
            - folder: if a string, only images whose file name starts with
                      it are returned
            - imageAddedRange, lastViewedRange: (min, max) tuples of epoch
              time stamps (passed to "to_timestamp"), limits included
            - viewcountRange, numAnnoRange, numPredRange: (min, max) tuples,
              limits included
            - orderBy: name of one of the returned columns to sort by.
                       Other values are ignored.
            - order: "asc" or "desc" (default); only used with "orderBy"
            - startFrom: image ID (UUID or UUID string); only images with
                         greater ID are returned. Invalid values are ignored.
            - limit: maximum number of images to return; restricted to the
                     range [1, NUM_IMAGES_LIMIT]. Invalid values result in
                     NUM_IMAGES_LIMIT being used.

            Images are sorted by ID (after "orderBy", if provided).
        '''
        queryArgs = []
        filters = []

        if folder is not None and isinstance(folder, str):
            filters.append('filename LIKE %s')
            queryArgs.append(folder + '%')

        # (min, max) range filters; limits are included
        rangeFilters = (
            (imageAddedRange,
             'date_added >= to_timestamp(%s) AND date_added <= to_timestamp(%s)'
             ),
            (lastViewedRange,
             'last_viewed >= to_timestamp(%s) AND last_viewed <= to_timestamp(%s)'
             ),
            (viewcountRange, 'viewcount >= %s AND viewcount <= %s'),
            (numAnnoRange, 'num_anno >= %s AND num_anno <= %s'),
            (numPredRange, 'num_pred >= %s AND num_pred <= %s'),
        )
        for valueRange, clause in rangeFilters:
            if valueRange is not None:
                filters.append(clause)
                queryArgs.append(valueRange[0])
                queryArgs.append(valueRange[1])

        if startFrom is not None and not isinstance(startFrom, UUID):
            try:
                startFrom = UUID(startFrom)
            except (ValueError, TypeError, AttributeError):
                # not a valid UUID; ignore
                startFrom = None
        if startFrom is not None:
            filters.append('img.id > %s')
            queryArgs.append(startFrom)

        if filters:
            filterStr = sql.SQL('WHERE ' + ' AND '.join(filters))
        else:
            filterStr = sql.SQL('')

        # "orderBy" and "order" end up in the query string itself, hence
        # only known column names and sort directions are accepted
        sortableColumns = ('id', 'filename', 'date_added', 'viewcount',
                           'last_viewed', 'num_anno', 'num_pred',
                           'isgoldenquestion')
        orderStr = sql.SQL('ORDER BY img.id ASC')
        if orderBy is not None:
            orderColumn = str(orderBy).strip().lower()
            if orderColumn in sortableColumns:
                direction = ('ASC' if str(order).strip().lower() == 'asc'
                             else 'DESC')
                orderStr = sql.SQL('ORDER BY {} {}, img.id ASC').format(
                    sql.SQL(orderColumn), sql.SQL(direction))

        if isinstance(limit, float):
            # NaN and infinity cannot be converted to int
            limit = int(limit) if math.isfinite(
                limit) else self.NUM_IMAGES_LIMIT
        elif isinstance(limit, str):
            try:
                limit = int(limit)
            except ValueError:
                limit = self.NUM_IMAGES_LIMIT
        elif not isinstance(limit, int):
            limit = self.NUM_IMAGES_LIMIT
        limit = max(min(limit, self.NUM_IMAGES_LIMIT), 1)
        limitStr = sql.SQL('LIMIT %s')
        queryArgs.append(limit)

        queryStr = sql.SQL('''
            SELECT img.id, filename, EXTRACT(epoch FROM date_added) AS date_added,
                COALESCE(viewcount, 0) AS viewcount,
                EXTRACT(epoch FROM last_viewed) AS last_viewed,
                COALESCE(num_anno, 0) AS num_anno,
                COALESCE(num_pred, 0) AS num_pred,
                img.isGoldenQuestion
            FROM {id_img} AS img
            FULL OUTER JOIN (
                SELECT image, COUNT(*) AS viewcount, MAX(last_checked) AS last_viewed
                FROM {id_iu}
                GROUP BY image
            ) AS iu
            ON img.id = iu.image
            FULL OUTER JOIN (
                SELECT image, COUNT(*) AS num_anno
                FROM {id_anno}
                GROUP BY image
            ) AS anno
            ON img.id = anno.image
            FULL OUTER JOIN (
                SELECT image, COUNT(*) AS num_pred
                FROM {id_pred}
                GROUP BY image
            ) AS pred
            ON img.id = pred.image
            {filter}
            {order}
            {limit}
        ''').format(id_img=sql.Identifier(project, 'image'),
                    id_iu=sql.Identifier(project, 'image_user'),
                    id_anno=sql.Identifier(project, 'annotation'),
                    id_pred=sql.Identifier(project, 'prediction'),
                    filter=filterStr,
                    order=orderStr,
                    limit=limitStr)

        result = self.dbConnector.execute(queryStr, tuple(queryArgs), 'all')

        # UUIDs are returned as strings
        for row in result:
            row['id'] = str(row['id'])
        return result

    def uploadImages(self,
                     project,
                     images,
                     existingFiles='keepExisting',
                     splitImages=False,
                     splitProperties=None):
        '''
            Receives a dict of files (bottle.py file format),
            verifies their file extension and checks if they
            are loadable by PIL.
            If they are, they are saved to disk in the project's
            image folder, and registered in the database.
            Parameter "existingFiles" can be set as follows:
            - "keepExisting" (default): if an image already exists on
              disk with the same path/file name, the new image will be
              renamed with an underscore and trailing number.
            - "skipExisting": do not save images that already exist on
              disk under the same path/file name.
            - "replaceExisting": overwrite images that exist with the
              same path/file name. Note: in this case all existing anno-
              tations, predictions, and other metadata about those images,
              will be removed from the database.
            
            If "splitImages" is True, the uploaded images will be automati-
            cally divided into patches on a regular grid according to what
            is defined in "splitProperties". For example, the following
            definition:

                splitProperties = {
                    'patchSize': (800, 600),
                    'stride': (400, 300),
                    'tight': True
                }

            would divide the images into patches of size 800x600, with over-
            lap of 50% (denoted by the "stride" being half the "patchSize"),
            and with all patches completely inside the original image (para-
            meter "tight" makes the last patches to the far left and bottom
            of the image being fully inside the original image; they are shif-
            ted if needed).
            Instead of the full images, the patches are stored on disk and re-
            ferenced through the database. The name format for patches is
            "imageName_x_y.jpg", with "imageName" denoting the name of the ori-
            ginal image, and "x" and "y" the left and top position of the patch
            inside the original image.

            Returns image keys for images that were successfully
            saved, and keys and error messages for those that
            were not.
        '''
        imgPaths_valid = []
        imgs_valid = []
        imgs_warn = {}
        imgs_error = {}
        for key in images.keys():
            try:
                nextUpload = images[key]
                nextFileName = nextUpload.raw_filename
                #TODO: check if raw_filename is compatible with uploads made from Windows

                # check if correct file suffix
                _, ext = os.path.splitext(nextFileName)
                if not ext.lower() in valid_image_extensions:
                    raise Exception(f'Invalid file type (*{ext})')

                # check if loadable as image
                cache = io.BytesIO()
                nextUpload.save(cache)
                try:
                    image = Image.open(cache)
                except Exception:
                    raise Exception('File is not a valid image.')

                # prepare image(s) to save to disk
                parent, filename = os.path.split(nextFileName)
                destFolder = os.path.join(
                    self.config.getProperty('FileServer', 'staticfiles_dir'),
                    project, parent)
                os.makedirs(destFolder, exist_ok=True)

                images = []
                filenames = []

                if not splitImages:
                    # upload the single image directly
                    images.append(image)
                    filenames.append(filename)

                else:
                    # split image into patches instead
                    images, coords = split_image(image,
                                                 splitProperties['patchSize'],
                                                 splitProperties['stride'],
                                                 splitProperties['tight'])
                    bareFileName, ext = os.path.splitext(filename)
                    filenames = [
                        f'{bareFileName}_{c[0]}_{c[1]}{ext}' for c in coords
                    ]

                # register and save all the images
                for i in range(len(images)):
                    subImage = images[i]
                    subFilename = filenames[i]

                    absFilePath = os.path.join(destFolder, subFilename)

                    # check if an image with the same name does not already exist
                    newFileName = subFilename
                    fileExists = os.path.exists(absFilePath)
                    if fileExists:
                        if existingFiles == 'keepExisting':
                            # rename new file
                            while (os.path.exists(absFilePath)):
                                # rename file
                                fn, ext = os.path.splitext(newFileName)
                                match = self.countPattern.search(fn)
                                if match is None:
                                    newFileName = fn + '_1' + ext
                                else:
                                    # parse number
                                    number = int(fn[match.span()[0] +
                                                    1:match.span()[1]])
                                    newFileName = fn[:match.span(
                                    )[0]] + '_' + str(number + 1) + ext

                                absFilePath = os.path.join(
                                    destFolder, newFileName)
                                if not os.path.exists(absFilePath):
                                    imgs_warn[
                                        key] = 'An image with name "{}" already exists under given path on disk. Image has been renamed to "{}".'.format(
                                            subFilename, newFileName)

                        elif existingFiles == 'skipExisting':
                            # ignore new file
                            imgs_warn[
                                key] = f'Image "{newFileName}" already exists on disk and has been skipped.'
                            imgs_valid.append(key)
                            imgPaths_valid.append(
                                os.path.join(parent, newFileName))
                            continue

                        elif existingFiles == 'replaceExisting':
                            # overwrite new file; first remove metadata
                            queryStr = sql.SQL('''
                                DELETE FROM {id_iu}
                                WHERE image = (
                                    SELECT id FROM {id_img}
                                    WHERE filename = %s
                                );
                                DELETE FROM {id_anno}
                                WHERE image = (
                                    SELECT id FROM {id_img}
                                    WHERE filename = %s
                                );
                                DELETE FROM {id_pred}
                                WHERE image = (
                                    SELECT id FROM {id_img}
                                    WHERE filename = %s
                                );
                                DELETE FROM {id_img}
                                WHERE filename = %s;
                            ''').format(
                                id_iu=sql.Identifier(project, 'image_user'),
                                id_anno=sql.Identifier(project, 'annotation'),
                                id_pred=sql.Identifier(project, 'prediction'),
                                id_img=sql.Identifier(project, 'image'))
                            self.dbConnector.execute(queryStr,
                                                     tuple([nextFileName] * 4),
                                                     None)

                            # remove file
                            try:
                                os.remove(absFilePath)
                                imgs_warn[key] = 'Image "{}" already existed on disk and has been deleted.\n'.format(newFileName) + \
                                                    'All metadata (views, annotations, predictions) have been removed from the database.'
                            except:
                                imgs_warn[key] = 'Image "{}" already existed on disk but could not be deleted.\n'.format(newFileName) + \
                                                    'All metadata (views, annotations, predictions) have been removed from the database.'

                    # write to disk
                    fileParent, _ = os.path.split(absFilePath)
                    if len(fileParent):
                        os.makedirs(fileParent, exist_ok=True)
                    subImage.save(absFilePath)

                    imgs_valid.append(key)
                    imgPaths_valid.append(os.path.join(parent, newFileName))

            except Exception as e:
                imgs_error[key] = str(e)

        # register valid images in database
        if len(imgPaths_valid):
            queryStr = sql.SQL('''
                INSERT INTO {id_img} (filename)
                VALUES %s
                ON CONFLICT (filename) DO NOTHING;
            ''').format(id_img=sql.Identifier(project, 'image'))
            self.dbConnector.insert(queryStr, [(i, ) for i in imgPaths_valid])

        result = {
            'imgs_valid': imgs_valid,
            'imgPaths_valid': imgPaths_valid,
            'imgs_warn': imgs_warn,
            'imgs_error': imgs_error
        }

        return result

    def scanForImages(self, project):
        '''
            Compares the files found in the project's image folder on disk
            (as listed by "listDirectory", recursively) with the file names
            registered in the project's "image" table.

            Returns a list of those paths that exist on disk, but have not
            (yet) been added to the database. The list is empty if the
            project folder does not exist. The order is arbitrary.
        '''
        # scan disk for files
        projectFolder = os.path.join(
            self.config.getProperty('FileServer', 'staticfiles_dir'), project)
        if not (os.path.isdir(projectFolder)
                or os.path.islink(projectFolder)):
            # no folder exists for the project (should not happen due to broadcast at project creation)
            return []

        # wrapping in set() is a no-op for a set and keeps the difference
        # below valid for any other iterable of paths
        imgs_disk = set(listDirectory(projectFolder, recursive=True))

        # get all existing file paths from database
        queryStr = sql.SQL('''
            SELECT filename FROM {id_img};
        ''').format(id_img=sql.Identifier(project, 'image'))
        result = self.dbConnector.execute(queryStr, None, 'all')

        # the connector may hand back None instead of an empty list (other
        # methods of this class guard against that as well)
        imgs_database = {row['filename'] for row in (result or [])}

        # filter
        return list(imgs_disk - imgs_database)

    def addExistingImages(self, project, imageList=None):
        '''
            Scans the project folder on the file system
            for images that are physically saved, but not
            (yet) added to the database.
            Adds them to the project's database schema.
            If an imageList iterable is provided, only
            the intersection between identified images on
            disk and in the iterable are added (for a dict,
            its keys are used).

            If 'imageList' is None or a string with contents
            'all' (case-insensitive), all untracked images
            will be added.

            Returns a tuple (status, result):
            - if there was nothing to add: (0, [])
            - otherwise "result" is whatever the database connector
              returned for the "id, filename" lookup of the added files
              (may be None), and "status" is 0 if that lookup yielded
              rows and 1 if it did not.
        '''
        # get all images on disk that are not in database
        imgs_candidates = self.scanForImages(project)

        addAll = imageList is None or (isinstance(imageList, str)
                                       and imageList.lower() == 'all')
        if addAll:
            imgs_add = imgs_candidates
        else:
            if isinstance(imageList, dict):
                imageList = list(imageList.keys())
            imgs_add = list(set(imgs_candidates).intersection(imageList))

        if not imgs_add:
            return 0, []

        # add to database. A file may have been registered by somebody else
        # (e.g. a concurrent upload) between the disk scan and this insert;
        # skip such rows instead of failing the whole batch, like the upload
        # routine does.
        queryStr = sql.SQL('''
            INSERT INTO {id_img} (filename)
            VALUES %s
            ON CONFLICT (filename) DO NOTHING;
        ''').format(id_img=sql.Identifier(project, 'image'))
        self.dbConnector.insert(queryStr, tuple([(i, ) for i in imgs_add]))

        # get IDs of newly added images
        queryStr = sql.SQL('''
            SELECT id, filename FROM {id_img}
            WHERE filename IN %s;
        ''').format(id_img=sql.Identifier(project, 'image'))
        result = self.dbConnector.execute(queryStr, (tuple(imgs_add), ), 'all')

        status = (0 if result is not None and len(result) else 1)  #TODO
        return status, result

    def removeImages(self,
                     project,
                     imageList,
                     forceRemove=False,
                     deleteFromDisk=False):
        '''
            Receives an iterable of image IDs (UUID strings or UUID
            objects) and removes them from the project database schema.

            If "forceRemove" is True, the images are removed together with
            all associated user views, annotations, and predictions.
            Otherwise only images are removed for which no user views,
            annotations, and predictions exist.
            If "deleteFromDisk" is True, the files of the removed images are
            also deleted from the project directory on the file system (if
            they exist there).

            Returns a list of the removed image entries (with keys "id" and
            "filename"; the ID is converted to a string). Returns an empty
            list if "imageList" is empty or nothing qualified for removal.

            Raises a ValueError if an ID string is not a valid UUID.
        '''

        # every ID goes into its own 1-tuple, as before; UUID instances are
        # accepted as they are since UUID() only parses strings
        imageList = tuple(
            (i if isinstance(i, UUID) else UUID(i), ) for i in imageList)
        if not imageList:
            # an empty tuple would be rendered as "IN ()", which is not
            # valid SQL; there is nothing to remove anyway
            return []

        id_img = sql.Identifier(project, 'image')
        id_iu = sql.Identifier(project, 'image_user')
        id_anno = sql.Identifier(project, 'annotation')
        id_pred = sql.Identifier(project, 'prediction')

        if forceRemove:
            queryStr = sql.SQL('''
                SELECT id, filename
                FROM {id_img}
                WHERE id IN %s;
            ''').format(id_img=id_img)
            queryArgs = (imageList, )

            deleteStr = sql.SQL('''
                DELETE FROM {id_iu} WHERE image IN %s;
                DELETE FROM {id_anno} WHERE image IN %s;
                DELETE FROM {id_pred} WHERE image IN %s;
                DELETE FROM {id_img} WHERE id IN %s;
            ''').format(id_iu=id_iu,
                        id_anno=id_anno,
                        id_pred=id_pred,
                        id_img=id_img)

        else:
            # only touch images that are referenced by no view, annotation
            # or prediction; the same condition is used to retrieve and to
            # delete the entries
            queryStr = sql.SQL('''
                SELECT id, filename
                FROM {id_img}
                WHERE id IN %s
                AND id NOT IN (
                    SELECT image FROM {id_iu}
                    WHERE image IN %s
                    UNION ALL
                    SELECT image FROM {id_anno}
                    WHERE image IN %s
                    UNION ALL
                    SELECT image FROM {id_pred}
                    WHERE image IN %s
                );
            ''').format(id_img=id_img,
                        id_iu=id_iu,
                        id_anno=id_anno,
                        id_pred=id_pred)
            queryArgs = (imageList, ) * 4

            deleteStr = sql.SQL('''
                DELETE FROM {id_img}
                WHERE id IN %s
                AND id NOT IN (
                    SELECT image FROM {id_iu}
                    WHERE image IN %s
                    UNION ALL
                    SELECT image FROM {id_anno}
                    WHERE image IN %s
                    UNION ALL
                    SELECT image FROM {id_pred}
                    WHERE image IN %s
                );
            ''').format(id_img=id_img,
                        id_iu=id_iu,
                        id_anno=id_anno,
                        id_pred=id_pred)

        # both delete statements take the ID list four times
        deleteArgs = (imageList, ) * 4

        # retrieve images to be deleted
        imgs_del = self.dbConnector.execute(queryStr, queryArgs, 'all')
        if not imgs_del:
            return []

        # delete images
        self.dbConnector.execute(deleteStr, deleteArgs, None)

        if deleteFromDisk:
            projectFolder = os.path.join(
                self.config.getProperty('FileServer', 'staticfiles_dir'),
                project)
            if os.path.isdir(projectFolder) or os.path.islink(projectFolder):
                for entry in imgs_del:
                    filePath = os.path.join(projectFolder, entry['filename'])
                    if os.path.isfile(filePath):
                        os.remove(filePath)

        # convert UUID
        for entry in imgs_del:
            entry['id'] = str(entry['id'])

        return imgs_del

    def removeOrphanedImages(self, project):
        '''
            Queries the project's image entries in the database and
            determines those for which no file is found in the project
            folder on disk anymore (as listed by "listDirectory",
            recursively). Removes those entries and all associated user
            views, annotations and predictions from the database.

            Returns a list with the IDs of the removed image entries (as
            returned by the database connector). The list is empty if
            nothing was orphaned, or if the project folder itself does not
            exist; in the latter case nothing is removed.
        '''
        # the connector may hand back None instead of an empty list
        imgs_DB = self.dbConnector.execute(
            sql.SQL('''
            SELECT id, filename FROM {id_img};
        ''').format(id_img=sql.Identifier(project, 'image')), None,
            'all') or []

        projectFolder = os.path.join(
            self.config.getProperty('FileServer', 'staticfiles_dir'), project)
        if not (os.path.isdir(projectFolder)
                or os.path.islink(projectFolder)):
            # without a project folder every entry would count as orphaned;
            # bail out instead of emptying the project
            return []
        imgs_disk = set(listDirectory(projectFolder, recursive=True))

        # get orphaned images
        imgs_orphaned = [
            entry['id'] for entry in imgs_DB
            if entry['filename'] not in imgs_disk
        ]
        if not imgs_orphaned:
            return []

        # remove; each of the four statements takes the same ID tuple
        self.dbConnector.execute(
            sql.SQL('''
            DELETE FROM {id_iu} WHERE image IN %s;
            DELETE FROM {id_anno} WHERE image IN %s;
            DELETE FROM {id_pred} WHERE image IN %s;
            DELETE FROM {id_img} WHERE id IN %s;
        ''').format(id_iu=sql.Identifier(project, 'image_user'),
                    id_anno=sql.Identifier(project, 'annotation'),
                    id_pred=sql.Identifier(project, 'prediction'),
                    id_img=sql.Identifier(project, 'image')),
            (tuple(imgs_orphaned), ) * 4, None)

        return imgs_orphaned

    def prepareDataDownload(self,
                            project,
                            dataType='annotation',
                            userList=None,
                            dateRange=None,
                            extraFields=None,
                            segmaskFilenameOptions=None,
                            segmaskEncoding='rgb'):
        '''
            Polls the database for project data according to the
            specified restrictions and writes it to a file:
            - project: shortname of the project (also its schema name)
            - dataType: "annotation" or "prediction"
            - userList: for type "annotation": None (all users), a single
                        user name, or an iterable of user names
            - dateRange: None (all dates), or one or two values for a mini-
                         mum and maximum timestamp. The values are passed to
                         SQL "to_timestamp". If only one value is given, the
                         current time is used as maximum.
            - extraFields: None (no field) or dict of keywords and bools for
                           additional fields (e.g. browser meta) to be
                           queried. Fields with a false value are left out;
                           "meta" defaults to False.
            - segmaskFilenameOptions: customization parameters for segmentation
                                      mask images' file names: None, or a
                                      dict with keys "baseName" ("filename"
                                      or "id"), "prefix" and "suffix".
            - segmaskEncoding: encoding of the segmentation mask pixel
                               values ("rgb" or "indexed")

            For projects with segmentation masks a ZIP archive is created
            that contains one TIFF image per entry (under
            "segmentation_masks/"), the metadata as "query.txt" and the
            label classes as "labelclasses.csv". For all other projects a
            single "*.txt" file with one "; "-separated line per entry is
            written.

            The file is created in "aide/downloadRequests/<project>" below
            this instance's temporary directory. Returns the bare file name
            (without directory). The dicts passed in are not modified.

            Raises an Exception if "dataType" is invalid.

            Note that in some cases (esp. for semantic segmentation),
            the number of queryable entries may be limited due to
            file size and free disk space restrictions. No upper ceiling
            is enforced here (TODO).
        '''

        now = datetime.now(tz=pytz.utc)

        # argument check
        if userList is None:
            userList = []
        elif isinstance(userList, str):
            userList = [userList]
        if dateRange is None:
            dateRange = []
        elif len(dateRange) == 1:
            # only a minimum given; use the current time as maximum. The
            # value ends up in "to_timestamp(%s)" and hence has to be the
            # plain number of seconds, not a datetime object.
            dateRange = [dateRange[0], now.timestamp()]

        if isinstance(extraFields, dict):
            # work on a copy to leave the caller's dict untouched
            extraFields = dict(extraFields)
            if not isinstance(extraFields.get('meta'), bool):
                extraFields['meta'] = False
        else:
            extraFields = {'meta': False}

        if segmaskFilenameOptions is None:
            segmaskFilenameOptions = {
                'baseName': 'filename',
                'prefix': '',
                'suffix': ''
            }
        else:
            # work on a copy to leave the caller's dict untouched
            segmaskFilenameOptions = dict(segmaskFilenameOptions)
            if segmaskFilenameOptions.get('baseName') not in ('filename',
                                                              'id'):
                segmaskFilenameOptions['baseName'] = 'filename'
            for optionKey in ('prefix', 'suffix'):
                try:
                    optionValue = str(segmaskFilenameOptions[optionKey])
                except Exception:
                    # missing or not convertible to a string
                    optionValue = ''
                for char in self.FILENAMES_PROHIBITED_CHARS:
                    optionValue = optionValue.replace(char, '_')
                segmaskFilenameOptions[optionKey] = optionValue

        # check metadata type: need to deal with segmentation masks separately
        if dataType == 'annotation':
            metaField = 'annotationtype'
        elif dataType == 'prediction':
            metaField = 'predictiontype'
        else:
            raise Exception('Invalid dataType specified ({})'.format(dataType))
        # metaField is one of the two constants above, hence safe to format in
        metaType = self.dbConnector.execute(
            '''
                SELECT {} FROM aide_admin.project
                WHERE shortname = %s;
            '''.format(metaField), (project, ), 1)[0][metaField]

        indexedColors = None
        if metaType.lower() == 'segmentationmasks':
            is_segmentation = True
            fileExtension = '.zip'

            # create indexed color palette for segmentation masks
            if segmaskEncoding == 'indexed':
                try:
                    indexedColors = []
                    labelClasses = self.dbConnector.execute(
                        sql.SQL('''
                            SELECT idx, color FROM {id_lc} ORDER BY idx ASC;
                        ''').format(
                            id_lc=sql.Identifier(project, 'labelclass')), None,
                        'all')
                    # NOTE(review): the palette starts with class index 1
                    # (no entry for the background) -- confirm that this
                    # matches the pixel values of the masks.
                    currentIndex = 1
                    for lc in labelClasses:
                        if lc['idx'] == 0:
                            # background class
                            continue
                        while currentIndex < lc['idx']:
                            # gaps in label classes; fill with zeros
                            indexedColors.extend([0, 0, 0])
                            currentIndex += 1
                        color = lc['color']
                        if color is None:
                            # no color specified; add from defaults
                            #TODO
                            indexedColors.extend([0, 0, 0])
                        else:
                            # convert to RGB format
                            indexedColors.extend(helpers.hexToRGB(color))
                        # this class has its palette entry now; without
                        # advancing, every following class would be treated
                        # as a gap and get an extra zero entry
                        currentIndex += 1

                except Exception:
                    # an error occurred; don't convert segmentation mask to indexed colors
                    indexedColors = None

        else:
            is_segmentation = False
            fileExtension = '.txt'  #TODO: support JSON?

        # prepare output file
        filename = 'aide_query_{}'.format(
            now.strftime('%Y-%m-%d_%H-%M-%S')) + fileExtension
        destPath = os.path.join(self.tempDir, 'aide/downloadRequests', project)
        os.makedirs(destPath, exist_ok=True)
        destPath = os.path.join(destPath, filename)

        # generate query
        queryArgs = []
        tableID = sql.Identifier(project, dataType)
        userStr = sql.SQL('')
        iuStr = sql.SQL('')
        dateStr = sql.SQL('')
        hasUserFilter = False
        queryFields = [
            'filename',
            'isGoldenQuestion',
            'date_image_added',
            'last_requested_image',
            'image_corrupt'  # default image fields
        ]
        if dataType == 'annotation':
            iuStr = sql.SQL('''
                JOIN (SELECT image AS iu_image, username AS iu_username, viewcount, last_checked, last_time_required FROM {id_iu}) AS iu
                ON t.image = iu.iu_image
                AND t.username = iu.iu_username
            ''').format(id_iu=sql.Identifier(project, 'image_user'))
            if len(userList):
                userStr = sql.SQL('WHERE username IN %s')
                queryArgs.append(tuple(userList))
                hasUserFilter = True

            queryFields.extend(
                getattr(QueryStrings_annotation, metaType).value)
            queryFields.extend([
                'username', 'viewcount', 'last_checked', 'last_time_required'
            ])  #TODO: make customizable

        else:
            queryFields.extend(
                getattr(QueryStrings_prediction, metaType).value)

        if len(dateRange):
            # append to the user restriction, or else start the WHERE clause
            if hasUserFilter:
                dateStr = sql.SQL(
                    ' AND timecreated >= to_timestamp(%s) AND timecreated <= to_timestamp(%s)'
                )
            else:
                dateStr = sql.SQL(
                    'WHERE timecreated >= to_timestamp(%s) AND timecreated <= to_timestamp(%s)'
                )
            queryArgs.extend(dateRange)

        if not is_segmentation:
            # join label classes
            lcStr = sql.SQL('''
                JOIN (SELECT id AS lcID, name AS labelclass_name, idx AS labelclass_index
                    FROM {id_lc}
                ) AS lc
                ON label = lc.lcID
            ''').format(id_lc=sql.Identifier(project, 'labelclass'))
            queryFields.extend(['labelclass_name', 'labelclass_index'])
        else:
            lcStr = sql.SQL('')

        # remove redundant query fields: drop duplicates (keeping the order,
        # so that the columns of the output are stable) and everything that
        # has been switched off through "extraFields"
        queryFields = [
            field for field in dict.fromkeys(queryFields)
            if extraFields.get(field, True)
        ]

        queryStr = sql.SQL('''
            SELECT * FROM {tableID} AS t
            JOIN (
                SELECT id AS imgID, filename, isGoldenQuestion, date_added AS date_image_added, last_requested AS last_requested_image, corrupt AS image_corrupt
                FROM {id_img}
            ) AS img ON t.image = img.imgID
            {lcStr}
            {iuStr}
            {userStr}
            {dateStr}
        ''').format(tableID=tableID,
                    id_img=sql.Identifier(project, 'image'),
                    lcStr=lcStr,
                    iuStr=iuStr,
                    userStr=userStr,
                    dateStr=dateStr)

        # query and process data; the "with" block makes sure the output
        # file is closed even if something fails half-way
        if is_segmentation:
            mainFile = zipfile.ZipFile(destPath, 'w', zipfile.ZIP_DEFLATED)
        else:
            mainFile = open(destPath, 'w')

        with mainFile:
            # NOTE(review): the header line lists all query fields, while
            # the lines below leave out "segmentationmask" -- confirm that
            # consumers of the file expect this.
            metaLines = ['; '.join(queryFields) + '\n']

            with self.dbConnector.execute_cursor(queryStr,
                                                 tuple(queryArgs)) as cursor:
                while True:
                    b = cursor.fetchone()
                    if b is None:
                        break

                    if is_segmentation:
                        # convert and store segmentation mask separately
                        if segmaskFilenameOptions['baseName'] == 'id':
                            # the ID may come as a UUID object, which cannot
                            # be concatenated with prefix and suffix
                            innerFilename = str(b['image'])
                            parent = ''
                        else:
                            parent, innerFilename = os.path.split(
                                b['filename'])
                        finalFilename = os.path.join(
                            parent,
                            segmaskFilenameOptions['prefix'] + innerFilename +
                            segmaskFilenameOptions['suffix'] + '.tif')
                        segmask_filename = 'segmentation_masks/' + finalFilename

                        segmask = base64ToImage(b['segmentationmask'],
                                                b['width'], b['height'])

                        if indexedColors is not None and len(
                                indexedColors) > 0:
                            # convert to indexed color and add color palette from label classes
                            segmask = segmask.convert('RGB').convert(
                                'P', palette=Image.ADAPTIVE, colors=3)
                            segmask.putpalette(indexedColors)

                        # save
                        bio = io.BytesIO()
                        segmask.save(bio, 'TIFF')
                        mainFile.writestr(segmask_filename, bio.getvalue())

                    # store metadata (the mask itself is stored as an image)
                    metaLines.append(''.join(
                        '{}; '.format(b[field.lower()])
                        for field in queryFields
                        if field.lower() != 'segmentationmask') + '\n')

            metaStr = ''.join(metaLines)

            if is_segmentation:
                mainFile.writestr('query.txt', metaStr)

                # append separate text file for label classes
                labelclassQuery = sql.SQL('''
                    SELECT id, name, color, labelclassgroup, idx AS labelclass_index
                    FROM {id_lc};
                ''').format(id_lc=sql.Identifier(project, 'labelclass'))
                result = self.dbConnector.execute(labelclassQuery, None,
                                                  'all')
                csvLines = ['id,name,color,labelclassgroup,labelclass_index\n']
                for r in (result or []):
                    csvLines.append('{},{},{},{},{}\n'.format(
                        r['id'], r['name'], r['color'], r['labelclassgroup'],
                        r['labelclass_index']))
                mainFile.writestr('labelclasses.csv', ''.join(csvLines))
            else:
                mainFile.write(metaStr)

        return filename

    def watchImageFolders(self):
        '''
            Queries all projects that have the image folder watch functionality
            enabled and updates the projects, one by one, with the latest image
            changes: untracked images on disk are added to the database and,
            if enabled for the project, database entries without a file on
            disk are removed. A summary is printed for every project in which
            something changed.
        '''
        projects = self.dbConnector.execute(
            '''
                SELECT shortname, watch_folder_remove_missing_enabled
                FROM aide_admin.project
                WHERE watch_folder_enabled IS TRUE;
            ''', None, 'all')

        # the connector may hand back None instead of an empty list
        for p in (projects or []):
            pName = p['shortname']

            # add new images. The result part is None if the lookup of the
            # added images yielded nothing; count that as "none added".
            _, imgs_added = self.addExistingImages(pName, None)
            numAdded = len(imgs_added or [])

            # remove orphaned images (if enabled)
            if p['watch_folder_remove_missing_enabled']:
                imgs_orphaned = self.removeOrphanedImages(pName)
                if numAdded or len(imgs_orphaned):
                    print(
                        f'[Project {pName}] {numAdded} new images found and added, {len(imgs_orphaned)} orphaned images removed from database.'
                    )

            elif numAdded:
                print(
                    f'[Project {pName}] {numAdded} new images found and added.'
                )

    def deleteProject(self, project, deleteFiles=False):
        '''
            Irreversibly deletes a project, including all data and metadata, from the database.
            If "deleteFiles" is True, then any data on disk (images, etc.) are also deleted.

            Returns 0 if "deleteFiles" is False. Otherwise returns a list
            with one entry for every problem encountered while removing the
            project folder (empty if everything could be deleted): a dict
            with keys "function", "path" and "excinfo" for every file or
            folder that could not be removed, or a string for any other
            error.

            This cannot be undone.
        '''
        print(f'Deleting project with shortname "{project}"...')

        # remove database entries
        print('\tRemoving database entries...')
        self.dbConnector.execute(
            '''
            DELETE FROM aide_admin.authentication
            WHERE project = %s;
            DELETE FROM aide_admin.project
            WHERE shortname = %s;
        ''', (
                project,
                project,
            ), None
        )  # already done by DataAdministration.middleware, but we do it again to be sure

        # compose the schema name as a quoted identifier, like everywhere
        # else in this class, instead of pasting it into the SQL string
        self.dbConnector.execute(
            sql.SQL('''
            DROP SCHEMA IF EXISTS {} CASCADE;
        ''').format(sql.Identifier(project)), None, None)

        if not deleteFiles:
            return 0

        print('\tRemoving files...')

        messages = []

        def _onError(function, path, excinfo):
            # called by shutil.rmtree for every path it fails on; collect
            # the problem and carry on (no debugger: this may run unattended)
            messages.append({
                'function': function,
                'path': path,
                'excinfo': excinfo
            })

        try:
            shutil.rmtree(os.path.join(
                self.config.getProperty('FileServer', 'staticfiles_dir'),
                project),
                          onerror=_onError)
        except Exception as e:
            messages.append(str(e))

        return messages
# Example no. 4
# 0
class ProjectStatisticsMiddleware:
    def __init__(self, config):
        '''
            Stores the given configuration and creates a "Database" instance
            from it, kept as "self.dbConnector" and used by the other methods
            of this class to run their queries.
        '''
        self.config = config
        self.dbConnector = Database(config)

    def getProjectStatistics(self, project):
        '''
            Collects class-agnostic statistics for a project: the number of
            images (total, viewed, golden questions, annotated) and the number
            of annotations, plus the number of viewed images and annotations
            per user (entry "user_stats", only present if the query returned
            per-user rows).
        '''
        identifiers = {
            'id_img': sql.Identifier(project, 'image'),
            'id_iu': sql.Identifier(project, 'image_user'),
            'id_anno': sql.Identifier(project, 'annotation'),
            'id_auth': sql.Identifier('aide_admin', 'authentication')
        }
        query = sql.SQL('''
            SELECT NULL AS username, COUNT(*) AS num_img, NULL::bigint AS num_anno FROM {id_img}
            UNION ALL
            SELECT NULL AS username, COUNT(DISTINCT(image)) AS num_img, NULL AS num_anno FROM {id_iu}
            UNION ALL
            SELECT NULL AS username, COUNT(DISTINCT(gq.id)) AS num_img, NULL AS num_anno FROM (
                SELECT id FROM {id_img} WHERE isGoldenQuestion = TRUE
            ) AS gq
            UNION ALL
            SELECT NULL AS username, NULL AS num_img, COUNT(DISTINCT(image)) AS num_anno FROM {id_anno}
            UNION ALL
            SELECT NULL AS username, NULL AS num_img, COUNT(*) AS num_anno FROM {id_anno}
            UNION ALL
            SELECT username, iu_cnt AS num_img, anno_cnt AS num_anno FROM (
            SELECT u.username, iu_cnt, anno_cnt
            FROM (
                SELECT DISTINCT(username) FROM (
                    SELECT username FROM {id_auth}
                    WHERE project = %s
                    UNION ALL
                    SELECT username FROM {id_iu}
                    UNION ALL
                    SELECT username FROM {id_anno}
                ) AS uQuery
            ) AS u
            LEFT OUTER JOIN (
                SELECT username, COUNT(*) AS iu_cnt
                FROM {id_iu}
                GROUP BY username
            ) AS iu
            ON u.username = iu.username
            LEFT OUTER JOIN (
                SELECT username, COUNT(*) AS anno_cnt
                FROM {id_anno}
                GROUP BY username
            ) AS anno
            ON u.username = anno.username
            ORDER BY u.username
        ) AS q;
        ''').format(**identifiers)
        rows = self.dbConnector.execute(query, (project, ), 'all')

        # the first five rows are the project-wide aggregates, in the order
        # in which the sub-queries appear above
        response = {
            'num_images': rows[0]['num_img'],
            'num_viewed': rows[1]['num_img'],
            'num_goldenQuestions': rows[2]['num_img'],
            'num_annotated': rows[3]['num_anno'],
            'num_annotations': rows[4]['num_anno']
        }

        # everything after that is one row per user
        userRows = rows[5:]
        if userRows:
            response['user_stats'] = {
                row['username']: {
                    'num_viewed': row['num_img'],
                    'num_annotations': row['num_anno']
                }
                for row in userRows
            }
        return response

    def getLabelclassStatistics(self, project):
        '''
            Counts annotations and predictions per label class and returns
            them as a dict of the form
            {label class name: {'num_anno': ..., 'num_pred': ...}}.
            An empty dict is returned if the query yields no rows.
            TODO: does not work for segmentationMasks (no label fields)
        '''
        identifiers = {
            'id_lc': sql.Identifier(project, 'labelclass'),
            'id_anno': sql.Identifier(project, 'annotation'),
            'id_pred': sql.Identifier(project, 'prediction')
        }
        query = sql.SQL('''
            SELECT lc.name, COALESCE(num_anno, 0) AS num_anno, COALESCE(num_pred, 0) AS num_pred
            FROM {id_lc} AS lc
            FULL OUTER JOIN (
                SELECT label, COUNT(*) AS num_anno
                FROM {id_anno} AS anno
                GROUP BY label
            ) AS annoCnt
            ON lc.id = annoCnt.label
            FULL OUTER JOIN (
                SELECT label, COUNT(*) AS num_pred
                FROM {id_pred} AS pred
                GROUP BY label
            ) AS predCnt
            ON lc.id = predCnt.label
        ''').format(**identifiers)
        rows = self.dbConnector.execute(query, None, 'all')

        if rows is None or len(rows) == 0:
            return {}

        # later rows with the same name overwrite earlier ones
        return {
            row['name']: {
                'num_anno': row['num_anno'],
                'num_pred': row['num_pred']
            }
            for row in rows
        }

    @staticmethod
    def _calc_geometric_stats(tp, fp, fn):
        tp, fp, fn = float(tp), float(fp), float(fn)
        try:
            precision = tp / (tp + fp)
        except:
            precision = 0.0
        try:
            recall = tp / (tp + fn)
        except:
            recall = 0.0
        try:
            f1 = 2 * precision * recall / (precision + recall)
        except:
            f1 = 0.0
        return precision, recall, f1

    def getPerformanceStatistics(self,
                                 project,
                                 entities_eval,
                                 entity_target,
                                 entityType='user',
                                 threshold=0.5,
                                 goldenQuestionsOnly=True):
        '''
            Compares the accuracy of a list of users or model states with a target
            user.
            The following measures of accuracy are reported, depending on the
            annotation type:
            - image labels: overall accuracy
            - points:
                    RMSE (distance to closest point with the same label; in pixels)
                    overall accuracy (labels)
            - bounding boxes:
                    IoU (max. with any target bounding box, regardless of label)
                    overall accuracy (labels)
            - segmentation masks:
                    overall accuracy, plus precision, recall and F1 per label class
                    (averaged over the evaluated masks)

            Value 'threshold' determines the geometric requirement for an annotation to be
            counted as correct (or incorrect) as follows:
                - points: maximum euclidean distance in pixels to closest target
                - bounding boxes: minimum IoU with best matching target

            If 'goldenQuestionsOnly' is True, only images with flag 'isGoldenQuestion' = True
            will be considered for evaluation.

            'entityType' selects the query set: 'user' (case-insensitive) compares
            users, any other value compares model states.

            Returns a dict with entries 'label_classes' (label class id as string ->
            tuple of (idx, color)) and 'per_entity' (entity -> statistics dict), or
            an empty dict if the label class query returns None.
        '''
        from copy import deepcopy

        entityType = entityType.lower()

        # get annotation type for project
        annoType = self.dbConnector.execute(
            '''SELECT annotationType
            FROM aide_admin.project WHERE shortname = %s;''', (project, ), 1)
        annoType = annoType[0]['annotationtype']

        # for segmentation masks: get label classes and their ordinals      #TODO: implement per-class statistics for all types
        lcDef = self.dbConnector.execute(
            sql.SQL('''
            SELECT id, idx, color FROM {id_lc};
        ''').format(id_lc=sql.Identifier(project, 'labelclass')), None, 'all')
        if lcDef is None:
            # no label classes defined
            return {}
        labelClasses = {
            str(lc['id']): (lc['idx'], lc['color'])
            for lc in lcDef
        }

        # compose args list and complete query
        queryArgs = [entity_target, tuple(entities_eval)]
        if annoType == 'points' or annoType == 'boundingBoxes':
            queryArgs.append(threshold)
            if annoType == 'points':
                # the points query takes the threshold twice
                queryArgs.append(threshold)

        if goldenQuestionsOnly:
            sql_goldenQuestion = sql.SQL('''JOIN (
                    SELECT id
                    FROM {id_img}
                    WHERE isGoldenQuestion = true
                ) AS qi
                ON qi.id = q2.image''').format(
                id_img=sql.Identifier(project, 'image'))
        else:
            sql_goldenQuestion = sql.SQL('')

        # result tokens: template of the per-entity statistics ("tokens") and
        # the entries that need to be normalized after accumulation
        tokens = {}
        tokens_normalize = []
        if annoType == 'labels':
            tokens = {
                'num_matches': 0,
                'correct': 0,
                'incorrect': 0,
                'overall_accuracy': 0.0
            }
            tokens_normalize = ['overall_accuracy']
        elif annoType == 'points':
            tokens = {
                'num_pred': 0,
                'num_target': 0,
                'tp': 0,
                'fp': 0,
                'fn': 0,
                'avg_dist': 0.0
            }
            tokens_normalize = ['avg_dist']
        elif annoType == 'boundingBoxes':
            tokens = {
                'num_pred': 0,
                'num_target': 0,
                'tp': 0,
                'fp': 0,
                'fn': 0,
                'avg_iou': 0.0
            }
            tokens_normalize = ['avg_iou']
        elif annoType == 'segmentationMasks':
            tokens = {
                'num_matches': 0,
                'overall_accuracy': 0.0,
                'per_class': {}
            }
            for clID in labelClasses:
                tokens['per_class'][clID] = {
                    'num_matches': 0,
                    'prec': 0.0,
                    'rec': 0.0,
                    'f1': 0.0
                }
            tokens_normalize = []

        if entityType == 'user':
            queryStr = getattr(StatisticalFormulas_user, annoType).value
            queryStr = sql.SQL(queryStr).format(
                id_anno=sql.Identifier(project, 'annotation'),
                id_iu=sql.Identifier(project, 'image_user'),
                sql_goldenQuestion=sql_goldenQuestion)

        else:
            queryStr = getattr(StatisticalFormulas_model, annoType).value
            queryStr = sql.SQL(queryStr).format(
                id_anno=sql.Identifier(project, 'annotation'),
                id_iu=sql.Identifier(project, 'image_user'),
                id_pred=sql.Identifier(project, 'prediction'),
                sql_goldenQuestion=sql_goldenQuestion)

        #TODO: update points query (according to bboxes); re-write stats parsing below

        # get stats
        response = {}
        with self.dbConnector.execute_cursor(queryStr,
                                             tuple(queryArgs)) as cursor:
            while True:
                b = cursor.fetchone()
                if b is None:
                    break

                if entityType == 'user':
                    entity = b['username']
                else:
                    entity = str(b['cnnstate'])

                if entity not in response:
                    # deep copy: the template contains nested dicts ('per_class')
                    # that must not be shared between entities
                    response[entity] = deepcopy(tokens)
                stats = response[entity]

                if annoType in ('points', 'boundingBoxes'):
                    # NOTE(review): this resets the counter for every row, so
                    # 'num_matches' only ever ends up as 1 or 2 and the average
                    # below is divided by that. Looks unintended, but the correct
                    # semantics depend on the (not visible) query - confirm before
                    # changing.
                    stats['num_matches'] = 1
                    if b['num_target'] > 0:
                        stats['num_matches'] += 1

                if annoType == 'segmentationMasks':
                    # decode segmentation masks
                    try:
                        mask_target = np.array(
                            base64ToImage(b['q1segmask'], b['q1width'],
                                          b['q1height']))
                        mask_source = np.array(
                            base64ToImage(b['q2segmask'], b['q2width'],
                                          b['q2height']))

                        if mask_target.shape == mask_source.shape and np.any(
                                mask_target) and np.any(mask_source):

                            # calculate OA over the pixels labeled in both masks
                            intersection = (mask_target > 0) * (mask_source >
                                                                0)
                            if np.any(intersection):
                                oa = np.mean(mask_target[intersection] ==
                                             mask_source[intersection])
                                stats['overall_accuracy'] += oa
                                stats['num_matches'] += 1

                            # calculate per-class precision and recall values
                            for clID in labelClasses:
                                idx = labelClasses[clID][0]
                                tp = np.sum((mask_target == idx) *
                                            (mask_source == idx))
                                fp = np.sum((mask_target != idx) *
                                            (mask_source == idx))
                                fn = np.sum((mask_target == idx) *
                                            (mask_source != idx))
                                if (tp + fp + fn) > 0:
                                    prec, rec, f1 = self._calc_geometric_stats(
                                        tp, fp, fn)
                                    classStats = stats['per_class'][clID]
                                    classStats['num_matches'] += 1
                                    classStats['prec'] += prec
                                    classStats['rec'] += rec
                                    classStats['f1'] += f1

                    except Exception as e:
                        print(
                            f'TODO: error in segmentation mask statistics calculation ("{str(e)}").'
                        )

                else:
                    if 'correct' in tokens:
                        # classification: evaluated once per row, outside the key
                        # loop, so that a row is not counted once per counter key
                        correct = b['label_correct']
                        # ignore None
                        if correct is True:
                            stats['correct'] += 1
                            stats['num_matches'] += 1
                        elif correct is False:
                            stats['incorrect'] += 1
                            stats['num_matches'] += 1

                    # accumulate all remaining tokens the row provides a value for
                    for key in tokens:
                        if key in ('correct', 'incorrect'):
                            continue
                        if key in b and b[key] is not None:
                            stats[key] += b[key]

        for stats in response.values():
            for t in tokens_normalize:
                if t not in stats:
                    continue
                if t == 'overall_accuracy':
                    # stays at 0.0 if not a single label could be compared
                    numRated = stats['correct'] + stats['incorrect']
                    if numRated > 0:
                        stats[t] = float(stats['correct']) / float(numRated)
                    else:
                        stats[t] = 0.0
                elif annoType in ('points', 'boundingBoxes'):
                    stats[t] /= stats['num_matches']

            if annoType == 'points' or annoType == 'boundingBoxes':
                prec, rec, f1 = self._calc_geometric_stats(
                    stats['tp'], stats['fp'], stats['fn'])
                stats['prec'] = prec
                stats['rec'] = rec
                stats['f1'] = f1

            elif annoType == 'segmentationMasks':
                # normalize OA (left at 0.0 if no mask pair could be evaluated)
                if stats['num_matches'] > 0:
                    stats['overall_accuracy'] /= stats['num_matches']

                # normalize all label class values as well
                for lcID in labelClasses:
                    classStats = stats['per_class'][lcID]
                    numMatches = classStats['num_matches']
                    if numMatches > 0:
                        classStats['prec'] /= numMatches
                        classStats['rec'] /= numMatches
                        classStats['f1'] /= numMatches

        return {'label_classes': labelClasses, 'per_entity': response}

    def getUserAnnotationSpeeds(self,
                                project,
                                users,
                                goldenQuestionsOnly=False):
        '''
            Returns, for each username in "users" list,
            the mean, median and lower and upper quartile
            (25% and 75%) of the time required in a given project.
        '''
        # prepare output
        response = {}
        for u in users:
            response[u] = {
                'avg': float('nan'),
                'median': float('nan'),
                'perc_25': float('nan'),
                'perc_75': float('nan')
            }

        if goldenQuestionsOnly:
            gqStr = sql.SQL('''
                JOIN {id_img} AS img
                ON anno.image = img.id
                WHERE img.isGoldenQuestion = true
            ''').format(id_img=sql.Identifier(project, 'image'))
        else:
            gqStr = sql.SQL('')

        queryStr = sql.SQL('''
            SELECT username, avg(timeRequired) AS avg,
            percentile_cont(0.50) WITHIN GROUP (ORDER BY timeRequired ASC) AS median,
            percentile_cont(0.25) WITHIN GROUP (ORDER BY timeRequired ASC) AS perc_25,
            percentile_cont(0.75) WITHIN GROUP (ORDER BY timeRequired ASC) AS perc_75
            FROM (
                SELECT username, timeRequired
                FROM {id_anno} AS anno
                {gqStr}
            ) AS q
            WHERE username IN %s
            GROUP BY username
        ''').format(id_anno=sql.Identifier(project, 'annotation'), gqStr=gqStr)
        result = self.dbConnector.execute(queryStr, (tuple(users), ), 'all')
        if result is not None:
            for r in result:
                user = r['username']
                response[user] = {
                    'avg': float(r['avg']),
                    'median': float(r['median']),
                    'perc_25': float(r['perc_25']),
                    'perc_75': float(r['perc_75']),
                }
        return response

    def getUserFinished(self, project, username):
        '''
            Tells whether the given user has viewed every image of the
            project: returns True if the number of images the user has viewed
            at least once reaches the total number of images, else False.
            Only this yes/no answer is exposed on purpose, so that the general
            user learns nothing more (e.g. to sustain the golden question
            limitation system).
        '''
        query = sql.SQL('''
            SELECT COUNT(*) AS cnt FROM {id_iu}
            WHERE viewcount > 0 AND username = %s
            UNION ALL
            SELECT COUNT(*) AS cnt FROM {id_img};
        ''').format(id_img=sql.Identifier(project, 'image'),
                    id_iu=sql.Identifier(project, 'image_user'))
        rows = self.dbConnector.execute(query, (username, ), 2)

        # first row: images viewed by the user; second row: images in total
        numViewed = rows[0]['cnt']
        numTotal = rows[1]['cnt']
        return numViewed >= numTotal

    def getTimeActivity(self,
                        project,
                        type='images',
                        numDaysMax=31,
                        perUser=False):
        '''
            Returns a histogram of the number of images viewed (if type = 'images')
            or annotations made (if type = 'annotations') over the last numDaysMax.
            If perUser is True, statistics are returned on a user basis.

            The result holds three parallel lists in chronological order:
            'counts', 'timestamps' (POSIX timestamp of the first event of the day)
            and 'labels' (day formatted as 'YYYY-Mon-dd'). With perUser = True
            the result is a dict {username: {'counts', 'timestamps', 'labels'}}
            instead.

            Note that numDaysMax limits the number of result rows; with
            perUser = True there is one row per day and user, so fewer days may
            be covered.
        '''
        if type == 'images':
            id_table = sql.Identifier(project, 'image_user')
            time_field = sql.SQL('last_checked')
        else:
            id_table = sql.Identifier(project, 'annotation')
            time_field = sql.SQL('timeCreated')

        if perUser:
            userSpec = sql.SQL(', username')
        else:
            userSpec = sql.SQL('')

        # the inner query picks the most recent rows (newest first, limited);
        # the outer query puts them back into chronological order
        queryStr = sql.SQL('''
            SELECT * FROM (
                SELECT to_char({time_field}, 'YYYY-Mon-dd') AS month_day, MIN({time_field}) AS date_of_day, COUNT(*) AS cnt {user_spec}
                FROM {id_table}
                WHERE {time_field} IS NOT NULL
                GROUP BY month_day {user_spec}
                ORDER BY date_of_day DESC
                LIMIT %s
            ) AS recent
            ORDER BY date_of_day ASC
        ''').format(time_field=time_field,
                    id_table=id_table,
                    user_spec=userSpec)
        result = self.dbConnector.execute(queryStr, (numDaysMax, ), 'all')

        #TODO: homogenize series and add missing days

        if perUser:
            response = {}
        else:
            response = {'counts': [], 'timestamps': [], 'labels': []}

        # "result or []": tolerate a None result
        for row in result or []:
            if perUser:
                # one series per user, created on first encounter
                series = response.setdefault(row['username'], {
                    'counts': [],
                    'timestamps': [],
                    'labels': []
                })
            else:
                series = response
            series['counts'].append(row['cnt'])
            series['timestamps'].append(row['date_of_day'].timestamp())
            series['labels'].append(row['month_day'])
        return response