Example #1
0
    def __init__(self, db_handler_or_db_uri):
        """
            Initializes a new DataManager via a given db_handler (s2sio.db.Dbhandler) or a
            db_uri (string)
            :param db_uri: the DbHandler or the database uri, e.g.
                "sqlite:///path_to_my_sqlite_file"
        """
        if isinstance(db_handler_or_db_uri, DbHandler):
            self.dbhandler = db_handler_or_db_uri
        else:
            self.dbhandler = DbHandler(db_handler_or_db_uri)

        self._classes_dataframe = self.dbhandler.read(self.dbhandler.T_CLS)
        self._classes_dataframe.insert(len(self._classes_dataframe.columns), 'Count', 0)
        self._classes_dataframe_needs_update = True
Example #2
0
class ClassAnnotator(object):
    """
        Class managing the annotation of class label on a downloaded database.
        An object of this class is initialized with a database handler
        s2sio.db.DbHandler or an URI (string) and allows setting / getting the labels
        for building training sets suitable for further machine learning processes
    """
    # the db column name of the annotated class_id:
    hl_class_id_colname = 'ClassIdHandLabeled'
    # the db column name of the (classified) class_id:
    class_id_colname = 'ClassId'

    def __init__(self, db_handler_or_db_uri):
        """
            Initializes a new DataManager via a given db_handler (s2sio.db.Dbhandler) or a
            db_uri (string)
            :param db_uri: the DbHandler or the database uri, e.g.
                "sqlite:///path_to_my_sqlite_file"
        """
        if isinstance(db_handler_or_db_uri, DbHandler):
            self.dbhandler = db_handler_or_db_uri
        else:
            self.dbhandler = DbHandler(db_handler_or_db_uri)

        self._classes_dataframe = self.dbhandler.read(self.dbhandler.T_CLS)
        self._classes_dataframe.insert(len(self._classes_dataframe.columns), 'Count', 0)
        self._classes_dataframe_needs_update = True

    def get_classes_df(self):
        """Returns the pandas DataFrame representing the classes. The DataFrame is read once
        in the constructor and updated here with a column 'Count' which counts the instances per
        class. Column names might vary across versions but in principle their names are 'Id',
        'Label' and 'Description' (plus the aforementioned 'Count')"""
        if self._classes_dataframe_needs_update:
            self._classes_dataframe_needs_update = False
            self.update_classes_count()
        return self._classes_dataframe

    def get_class(self, segment_id):
        """Returns the class id (integer) of the index-th item (=db table row).
        :param segment_id. The id (primary key) of a given segment. This method is mainly used
            in conjunction to a s2sio.db.ListReader object (which might have been passed to
            the constructor, as it extends s2sio.db.DbHandler). A ListReader is a DbHandler which
            reads and returns segments by their id in a python-list fashion
        """
        dbh = self.dbhandler
        where = (dbh.column(dbh.T_SEG, "Id") == segment_id)
        pddf = dbh.select([dbh.T_SEG], where)

        return pddf.iloc[0][self.class_id_colname]

    def set_class(self, segment_id, class_id):
        """
            Sets the class of the index-th item (=db table row).
            :param segment_id. The id (primary key) of a given segment. This method is mainly used
            in conjunction to a s2sio.db.ListReader object (which might have been passed to
            the constructor, as it extends s2sio.db.DbHandler). A ListReader is a DbHandler which
            reads and returns segments by their id in a python-list fashion
            :param class_id: one of the classes id. To get them, call self.get_classes()['Id']
            (returns a pandas Series object)
            :return: self.get_classes_df (updated with the new 'Count' column)
        """
        dbh = self.dbhandler
        # store the old class id and the new one:
        # NOTE: we absolutely need a with statement to keep the session open
        with dbh.session_scope() as session:
            id_colname = dbh.table_settings[dbh.T_SEG]['pkey']
            row = session.query(dbh.tables[dbh.T_SEG])
            row = row.filter(dbh.column(dbh.T_SEG, id_colname) == segment_id).first()
            # set new class id:
            setattr(row, self.class_id_colname, class_id)
            # set manual hand labelled to true, unless we set to "unknown" in which
            # case we reset the hand labelling to false (as if we didn't annotated it)
            setattr(row, self.hl_class_id_colname, class_id != UNKNOWN_CLASS_ID)

    def update_classes_count(self):
        """internal method used when calling set_class to update the counts of the given
        classes"""
        dbh = self.dbhandler
        session = dbh.session()
        classes_dataframe = self.get_classes_df()
        table = dbh.table(dbh.T_SEG)
        id_colname = dbh.table_settings[dbh.T_CLS]['pkey']

        def countfunc(row):
            """function setting the 'Count' column for each DataFrame row"""
            row['Count'] = session.query(table).filter(table.ClassId ==
                                                       row[id_colname]).count()
            return row

        self._classes_dataframe = classes_dataframe.apply(countfunc, axis=1)
        session.close()