Ejemplo n.º 1
0
 def test_dict(self):
     '''
     Run utils.checkTransformGroup on a dict of per-key callables.

     There is no explicit assertion: the test succeeds as long as the call
     does not raise.
     '''
     per_key_transforms = dict(
         image=lambda x: x[:, :, 0],
         name=lambda _: 'hi',
     )
     # NOTE(review): `sample` is built but not used by anything visible in
     # this test -- confirm whether it is a leftover fixture.
     sample = dict(image=np.zeros((10, 10, 3)), objectid=1, name='car')
     utils.checkTransformGroup(per_key_transforms)
Ejemplo n.º 2
0
    def __init__(self,
                 db_file,
                 rootdir='.',
                 where_image='TRUE',
                 where_object='TRUE',
                 mode='r',
                 copy_to_memory=True,
                 used_keys=None,
                 transform_group=None):
        '''
        Open the database, select the image entries, and store the settings
        used later when samples are produced.

        Args:
            db_file:        (string) Path to an sqlite3 database file.
            rootdir:        (string) Root path that is prepended to the
                            "imagefile" entries of the "images" table.
            where_image:    (string) The WHERE part of the query on the
                            "images" table, i.e. the query has the form
                            "SELECT * FROM images WHERE ${where_image};"
                            Use it to select only the images you need.
            where_object:   (string) The WHERE part of the query on the
                            "objects" table, i.e. the query has the form
                            "SELECT * FROM objects WHERE ${where_object};"
                            Use it to select only the objects you need
                            for each image.
            mode:           ("r" or "w") Open the database read-only or
                            read-write. Defaults to "r"; pass "w" only if
                            you plan to call addRecord().
            copy_to_memory: (bool) Copy the database into memory.
                            Only for mode="r". Should be used for python2.
            used_keys:      (None or list of str)
                            By default __getitem__ returns a dict with keys
                            'image', 'mask', 'objects', 'imagefile', 'name',
                            'score' (see the comments to this class).
                            `used_keys` tells which of these keys to keep;
                            the others can be disposed of.
            transform_group: (a callable or a dict string -> callable)
                            Transform(s) applied to a sample.
                            A callable is applied to the whole sample.
                            For a dict, every key is matched to a key of
                            the sample, and the callable is called on the
                            corresponding element.
        '''

        self.mode = mode
        self.conn = utils.openConnection(db_file, mode, copy_to_memory)
        cursor = self.conn.cursor()
        self.c = cursor

        # The clause is validated before it is spliced into the query text.
        utils.checkWhereImage(where_image)
        image_query = ('SELECT * FROM images WHERE %s ORDER BY imagefile' %
                       where_image)
        cursor.execute(image_query)
        self.image_entries = cursor.fetchall()

        self.imreader = backendMedia.MediaReader(rootdir=rootdir)

        # Only validated and stored here; no "objects" query is run in
        # this constructor.
        utils.checkWhereObject(where_object)
        self.where_object = where_object

        _checkUsedKeys(used_keys)
        self.used_keys = used_keys

        utils.checkTransformGroup(transform_group)
        self.transform_group = transform_group
Ejemplo n.º 3
0
def _object_dataset_function(db_file,
                             rootdir='.',
                             where_object='TRUE',
                             used_keys=None,
                             transform_group=None):
    '''
    Eagerly build a list of samples, one per object selected from the database.

    Args:
        db_file:        (string) A path to an sqlite3 database file.
        rootdir:        (string) A root path, is pre-appended to "imagefile"
                        entries of "images" table in the sqlite3 database.
        where_object:   (string) The WHERE part of the SQL query on the
                        "objects" table, as in:
                        "SELECT * FROM objects WHERE ${where_object};"
                        Allows to query only needed objects.
        used_keys:      (a list of strings or None) If specified, use only
                        these keys for every sample, and discard the rest.
        transform_group: A dict {string: callable}. Each key of this dict
                        should match a key in each sample, and the callables
                        are applied to the respective sample dict values.

    Returns:
        A list of samples in the order of "objectid". Entries for which
        utils.buildObjectSample returns None are skipped, so the list may be
        shorter than the number of selected objects.
    '''

    conn = utils.openConnection(db_file, 'r', copy_to_memory=False)
    try:
        c = conn.cursor()
        # Validate the clause before it is spliced into the query text.
        utils.checkWhereObject(where_object)
        c.execute('SELECT * FROM objects WHERE %s ORDER BY objectid' %
                  where_object)
        object_entries = c.fetchall()

        imreader = backendMedia.MediaReader(rootdir=rootdir)

        _checkUsedKeys(used_keys)
        utils.checkTransformGroup(transform_group)

        samples = []
        logging.info('Loading samples...')
        for index, object_entry in enumerate(object_entries):
            sample = utils.buildObjectSample(object_entry, c, imreader)
            if sample is None:
                # Make dropped entries visible instead of skipping silently.
                logging.warning('Skip bad sample %d', index)
                continue
            sample = _filterKeys(used_keys, sample)
            sample = utils.applyTransformGroup(transform_group, sample)
            samples.append(sample)
        # Report what was actually kept, not how many rows were selected.
        logging.info('Loaded %d samples.', len(samples))
        return samples
    finally:
        # All samples are built by now; the connection is local to this
        # function, so release it deterministically.
        conn.close()
Ejemplo n.º 4
0
    def __init__(self,
                 db_file,
                 rootdir='.',
                 where_object='TRUE',
                 mode='r',
                 copy_to_memory=True,
                 used_keys=None,
                 transform_group=None,
                 preload_samples=False):
        '''
        Open the database, select the object entries, and optionally build
        all samples up front.

        Args:
            db_file:        (string) A path to an sqlite3 database file.
            rootdir:        (string) A root path, is pre-appended to "imagefile"
                            entries of "images" table in the sqlite3 database.
            where_object:   (string) The WHERE part of the SQL query on the
                            "objects" table, as in:
                            "SELECT * FROM objects WHERE ${where_object};"
                            Allows to query only needed objects.
            mode:           ("r" or "w") The readonly or write-read mode to open
                            the database. The default is "r", use "w" only if
                            you plan to call addRecord().
            copy_to_memory: (bool) Copies database into memory.
                            Only for mode="r". Should be used for python2.
            used_keys:      (a list of strings or None) If specified, use only
                            these keys for every sample, and discard the rest.
            transform_group: ((1) a callable, or (2) a list of callables,
                            or (3) a dict {string: callable})
                            Transform(s) to be applied on a sample.
                            (1) A callable: It is applied to the sample.
                            (2) A list of callables: Each callable is applied
                                to the sample sequentially.
                            (3) A dict {string: callable}: Each key of this dict
                            should match a key in each sample, and the callables
                            are applied to the respective sample dict values.
            preload_samples:  (bool) If true, will try to preload all samples
                            (including images) into memory in __init__.
                            Entries for which utils.buildObjectSample returns
                            None are logged and left out, so
                            `self.preloaded_samples` may be shorter than
                            `self.object_entries`.
        '''

        self.mode = mode
        self.conn = utils.openConnection(db_file, mode, copy_to_memory)
        self.c = self.conn.cursor()
        # Validate the clause before it is spliced into the query text.
        utils.checkWhereObject(where_object)
        self.c.execute('SELECT * FROM objects WHERE %s ORDER BY objectid' %
                       where_object)
        self.object_entries = self.c.fetchall()

        self.imreader = backendMedia.MediaReader(rootdir=rootdir)

        _checkUsedKeys(used_keys)
        self.used_keys = used_keys
        utils.checkTransformGroup(transform_group)
        self.transform_group = transform_group

        if not preload_samples:
            self.preloaded_samples = None
        else:
            self.preloaded_samples = []
            logging.info('Loading samples...')
            # Iterate the entries directly rather than range(len(self)):
            # the loop must not depend on how __len__ behaves while
            # preloaded_samples is still being filled.
            for index, object_entry in enumerate(self.object_entries):
                sample = utils.buildObjectSample(object_entry, self.c,
                                                 self.imreader)
                if sample is None:
                    logging.warning('Skip bad sample %d', index)
                else:
                    self.preloaded_samples.append(sample)
            # Report the number of samples actually kept.
            logging.info('Loaded %d samples.', len(self.preloaded_samples))
Ejemplo n.º 5
0
    def __init__(self,
                 db_file,
                 rootdir='.',
                 where_image='TRUE',
                 where_object='TRUE',
                 mode='r',
                 copy_to_memory=True,
                 used_keys=None,
                 transform_group=None,
                 batch_size=1,
                 shuffle=False):
        '''
        Open the database, select the image entries, store the batching
        settings, and call on_epoch_end() once.

        Args:
            db_file:        (string) A path to an sqlite3 database file.
            rootdir:        (string) A root path, is pre-appended to "imagefile"
                            entries of "images" table in the sqlite3 database.
            where_image:    (string) The WHERE part of the SQL query on the
                            "images" table, as in:
                            "SELECT * FROM images WHERE ${where_image};"
                            Allows to query only needed images.
            where_object:   (string) The WHERE part of the SQL query on the
                            "objects" table, as in:
                            "SELECT * FROM objects WHERE ${where_object};"
                            Allows to query only needed objects for each image.
            mode:           ("r" or "w") The readonly or write-read mode to open
                            the database. The default is "r", use "w" only if
                            you plan to call addRecord().
            copy_to_memory: (bool) Copies database into memory.
                            Only for mode="r". Should be used for python2.
            used_keys:      (None, list of str, tuple of str, or dict str -> str)
                            Originally __getitem__ returns a dict with keys:
                            'image', 'mask', 'objects', 'imagefile', 'name',
                            'score' (see the comments to this class above).
                            Argument `used_keys` determines which of these keys
                            are needed, and which can be disposed of.
                            Options for `used_keys`.
                            1) None. Each sample is unchanged dict.
                            2) List of str. Each str is a key.
                               __getitem__ returns a list.
                            3) Tuple of str. Same as above.
                               __getitem__ returns a tuple.
                            4) Dict str -> str. The key is the key in the
                               database, the value is the key in the output dict.
                               __getitem__ returns a dict.
            transform_group: (a callable or a dict string -> callable)
                            Transform(s) to be applied on a sample.
                            1) If it is a callable, it is applied to the sample.
                            2) If it is a dict, each key is matched to a key in
                               the sample (after used dict), and callables are
                               called on the respective elements.
            batch_size:     Stored as `self.batch_size`; it is not used
                            further inside this constructor.
            shuffle:        Stored as `self.shuffle`; it is not used further
                            inside this constructor.

        Raises:
            ValueError: if `db_file` does not exist.
        '''
        self.batch_size = batch_size
        self.shuffle = shuffle

        if not op.exists(db_file):
            raise ValueError('db_file does not exist: %s' % db_file)

        self.mode = mode
        self.conn = utils.openConnection(db_file, mode, copy_to_memory)
        self.c = self.conn.cursor()
        # Validate the clause before it is spliced into the query text,
        # the same way the other dataset constructors do.
        utils.checkWhereImage(where_image)
        self.c.execute('SELECT * FROM images WHERE %s ORDER BY imagefile' %
                       where_image)
        self.image_entries = self.c.fetchall()

        self.imreader = backendMedia.MediaReader(rootdir=rootdir)

        # Only validated and stored here; no "objects" query is run in
        # this constructor.
        utils.checkWhereObject(where_object)
        self.where_object = where_object

        _checkUsedKeys(used_keys)
        self.used_keys = used_keys
        utils.checkTransformGroup(transform_group)
        self.transform_group = transform_group

        # Called last, after every attribute above has been assigned.
        self.on_epoch_end()
Ejemplo n.º 6
0
 def test_list(self):
     '''
     Run utils.checkTransformGroup on a list of callables.

     There is no explicit assertion: the test succeeds as long as the call
     does not raise.
     '''
     sequential_transforms = [
         lambda x: x['image'],
         lambda x: x['objectid'],
     ]
     # NOTE(review): `sample` is built but not used by anything visible in
     # this test -- confirm whether it is a leftover fixture.
     sample = dict(image=np.zeros(3), objectid=1, name='car')
     utils.checkTransformGroup(sequential_transforms)
Ejemplo n.º 7
0
 def test_callable(self):
     '''
     Run utils.checkTransformGroup on a single callable.

     There is no explicit assertion: the test succeeds as long as the call
     does not raise.
     '''
     def pick_image(x):
         return x['image']

     # NOTE(review): `sample` is built but not used by anything visible in
     # this test -- confirm whether it is a leftover fixture.
     sample = dict(image=np.zeros(3), objectid=1)
     utils.checkTransformGroup(pick_image)