def test_save(self):
    """Test ExtendedModel.save writes attributes into their typed value tables."""
    pk = SampleExtModel({
        'static': 'static_value',
        'sample_attr_1': 57,
        'sample_attr_2': 'fifty-seven',
        'sample_attr_3': 57.0
    }).save().id()

    def stored_row(attr_type, code):
        # Fetch the row holding attribute `code` of type `attr_type` for
        # the model saved above (at most one row).
        value_table = 'sample_attribute_%s' % attr_type
        select = DBSelect((value_table, 'v'), ('value', ))
        select = select.join(('attribute', 'a'), 'v.attribute = a._id', ())
        select = select.where('a.code = ?', code)
        select = select.where('a.parent = ?', 'sample')
        select = select.where('a.type = ?', attr_type)
        select = select.where('v.parent = ?', pk)
        return select.limit(1).query().fetchone()

    self.assertEqual(stored_row('integer', 'sample_attr_1')['value'], 57)
    self.assertEqual(stored_row('text', 'sample_attr_2')['value'],
                     'fifty-seven')
    # expected to find nothing: per the original test, 'sample_attr_3'
    # belongs to another attribute group, so it is not stored here
    self.assertIsNone(stored_row('real', 'sample_attr_3'))
def test_select_update(self):
    """Test executing UPDATEs based on SELECTs."""
    # (select factory, values to set, expected number of affected rows);
    # factories keep each DBSelect built right before its own update runs
    steps = [
        (lambda: DBSelect('test_table').where('col_a = ?', 1),
         {'col_b': 512}, 1),
        (lambda: DBSelect('test_table').order('col_a').limit(3, 1),
         {'col_b': 512}, 3),
        (lambda: DBSelect('test_table').where('col_b = ?', 512),
         {'col_b': 1024}, 4),
    ]
    for make_select, new_values, affected in steps:
        self.assertEqual(make_select().query_update(new_values), affected)
def test_select_join(self):
    """Test the DBSelect *_join methods (inner and left joins)."""
    self._helper.query('CREATE TABLE a (x INTEGER)')
    self._helper.query('CREATE TABLE b (y INTEGER)')
    self._helper.insert('a', ({'x': 1}, {'x': 2}, {'x': 3}))
    self._helper.insert('b', ({'y': 2}, {'y': 3}, {'y': 4}))

    # x and y only share the values 2 and 3
    sql = DBSelect('a', {'n': 'COUNT(*)'}).inner_join('b', 'a.x = b.y', ())
    self.assertEqual(self._helper.query(sql.render()).fetchone()['n'], 2)

    # a LEFT JOIN keeps every row of "a"
    sql = DBSelect('a', {'n': 'COUNT(*)'}).left_join('b', 'a.x = b.y', ())
    self.assertEqual(self._helper.query(sql.render()).fetchone()['n'], 3)
def test_select_delete(self):
    """Test executing DELETEs based on SELECTs."""
    expected_bulk = self._num_rows - 10
    bulk = DBSelect('test_table').order('col_a', 'DESC').limit(expected_bulk)
    self.assertEqual(bulk.query_delete(), expected_bulk)

    single = DBSelect('test_table').where('col_a = ?', 1)
    self.assertEqual(single.query_delete(), 1)

    # col_a = 1 was removed just above, hence 3 rather than 4 deletions
    several = DBSelect('test_table').where('col_a IN (?)', (1, 2, 3, 4))
    self.assertEqual(several.query_delete(), 3)
def test_delete(self):
    """Test ExtendedModel.delete also removes rows from attribute tables."""
    model = SampleExtModel({
        'static': 'something',
        'sample_attr_1': 57
    })
    pk = model.save().id()

    def integer_attr_count():
        # number of rows currently in the integer attribute value table
        return len(DBSelect('sample_attribute_integer').query().fetchall())

    self.assertEqual(integer_attr_count(), 1)
    SampleExtModel().load(pk).delete()
    self.assertEqual(integer_attr_count(), 0)
def test_install(self):
    """Test InstallHelper.install applies routines and tracks the version.

    Runs against a throwaway SQLite file. The helpers are reset and the
    file is removed in a ``finally`` block, so a failing assertion does not
    leak state into later tests.
    """
    db_path = '/tmp/box.db'
    if os.path.isfile(db_path):
        os.unlink(db_path)
    DBHelper().set_db(db_path)
    InstallHelper.reset()
    try:
        module = 'sample_module'
        q = DBHelper.quote_identifier

        # version 1: create the table
        install_routines = ["""
            CREATE TABLE IF NOT EXISTS %s (
                id INTEGER
            )
        """ % q(module)]
        InstallHelper.install(module, install_routines)
        self.assertIn('id', DBHelper().describe_table(module))
        self.assertEqual(InstallHelper.version(module), 1)

        # version 2: a callable routine inserting two rows
        install_routines.append(
            lambda: DBHelper().insert(module, [{'id': 1}, {'id': 2}]))
        InstallHelper.install(module, install_routines)
        self.assertEqual(
            DBSelect(module, {'c': 'COUNT(*)'}).query().fetchone()['c'],
            2
        )
        self.assertEqual(InstallHelper.version(module), 2)

        # version 3: an ALTER TABLE routine
        install_routines.append(
            'ALTER TABLE %s ADD COLUMN %s TEXT' % (q(module), q('col')))
        InstallHelper.install(module, install_routines)
        self.assertIn('col', DBHelper().describe_table(module))
        self.assertEqual(InstallHelper.version(module), 3)
    finally:
        InstallHelper.reset()
        DBHelper().set_db(None)
        if os.path.isfile(db_path):
            os.unlink(db_path)
def test_select_unset(self):
    """Test unsetting individual parts of a SELECT."""
    def flat(query):
        # collapse all whitespace runs so rendering layout is irrelevant
        return re.sub(r'\s+', ' ', str(query))

    select = (DBSelect('a', {'c': 'col'})
              .left_join('b', 'a.c = b.d', {'d': 'col'})
              .where('a.col = ?', 1)
              .order('b.d', 'DESC')
              .limit(1, 2)
              .distinct(True))
    self.assertEqual(
        flat(select),
        'SELECT DISTINCT "a"."col" AS "c", "b"."col" AS "d" '
        'FROM "a" LEFT JOIN "b" ON a.c = b.d '
        'WHERE (a.col = ?) '
        'ORDER BY "b"."d" DESC '
        'LIMIT 2, 1'
    )

    select.unset(select.FROM | select.COLUMNS).set_from('x')
    self.assertEqual(
        flat(select),
        'SELECT DISTINCT "x".* '
        'FROM "x" WHERE (a.col = ?) ORDER BY "b"."d" DESC LIMIT 2, 1'
    )

    select.unset(select.WHERE)
    self.assertEqual(
        flat(select),
        'SELECT DISTINCT "x".* '
        'FROM "x" ORDER BY "b"."d" DESC LIMIT 2, 1'
    )

    select.unset(select.DISTINCT | select.ORDER | select.LIMIT)
    self.assertEqual(flat(select), 'SELECT "x".* FROM "x"')
def test_select_where(self):
    """Test building the WHERE part of DBSelect with where()/or_where()."""
    def flat(query):
        # collapse all whitespace runs so rendering layout is irrelevant
        return re.sub(r'\s+', ' ', str(query))

    select = DBSelect('test_table')

    select.where('"col_a" = ?', 1)
    self.assertEqual(
        flat(select),
        'SELECT "test_table".* FROM "test_table" WHERE ("col_a" = ?)'
    )
    self.assertEqual(len(select.query().fetchall()), 1)

    select.or_where('"col_a" = ?', 2)
    self.assertEqual(
        flat(select),
        'SELECT "test_table".* FROM "test_table" '
        'WHERE ("col_a" = ?) OR ("col_a" = ?)'
    )
    self.assertEqual(len(select.query().fetchall()), 2)

    # Conditions are joined without extra grouping, so SQL precedence
    # evaluates this as A OR (B AND C); the test expects one matching row.
    select.where('"col_a" IN (?)', (3, 4, 5))
    self.assertEqual(
        flat(select),
        'SELECT "test_table".* '
        'FROM "test_table" '
        'WHERE ("col_a" = ?) '
        'OR ("col_a" = ?) '
        'AND ("col_a" IN (?, ?, ?))'
    )
    self.assertEqual(len(select.query().fetchall()), 1)
def _db_select(self, key=None):
    """Build a one-row DBSelect locating this model by a column value.

    :param key: column to match on; when falsy, the class-level ``_pk``
        is used instead.
    :return: DBSelect over ``self.get_table()`` filtered on
        ``<key> = self.get_data(key)`` and limited to one row.
    """
    column = key or self.__class__._pk
    condition = '%s = ?' % (DBHelper.quote_identifier(column), )
    select = DBSelect(self.get_table())
    return select.where(condition, self.get_data(column)).limit(1)
def get_all_filters(cls):
    """Present filters and their options in a JSON-encodable structure.

    :return: list with one dict per row of ``cls._filter_table``, each with
        ``param`` (the filter code), ``label``, ``multi`` (always True) and
        ``options`` mapping option value -> option label.
    """
    # Load every option in one query and bucket them by filter id, rather
    # than issuing a separate options query per filter (N+1 queries).
    options_by_filter = {}
    for opt in DBSelect(cls._filter_option_table).query().fetchall():
        options_by_filter.setdefault(opt['filter'], []).append(opt)

    filters = []
    for row in DBSelect(cls._filter_table).query().fetchall():
        filters.append({
            'param': row['code'],
            'label': row['label'],
            'multi': True,
            'options': {
                opt['value']: opt['label']
                for opt in options_by_filter.get(row['_id'], ())
            },
        })
    return filters
def get_time_attr_select():
    """Return a DBSelect over stored file-timestamp attribute values.

    The value table (aliased ``tt``) is derived from the time attribute's
    parent and type. Only rows for that attribute whose value is non-NULL
    and greater than zero are selected.
    """
    time_attr = FileTimeIndexer.get_time_attribute()
    value_table = '%s_attribute_%s' % (time_attr['parent'], time_attr['type'])
    select = DBSelect((value_table, 'tt'))
    select = select.where('tt.attribute = ?', time_attr['_id'])
    select = select.where('tt.value IS NOT NULL')
    return select.where('tt.value > 0')
def test_select_order(self):
    """Test the ORDER BY part of DBSelect."""
    def flat(query):
        # collapse all whitespace runs so rendering layout is irrelevant
        return re.sub(r'\s+', ' ', str(query))

    self._helper.query('CREATE TABLE a (x INTEGER, y INTEGER)')
    fixtures = (
        {'x': 1, 'y': 1},
        {'x': 2, 'y': 2},
        {'x': 3, 'y': 2},
    )
    self._helper.insert('a', fixtures)

    by_x = DBSelect('a').order('x')
    self.assertEqual(flat(by_x), 'SELECT "a".* FROM "a" ORDER BY "x" ASC')
    self.assertEqual(by_x.query().fetchone()['x'], 1)

    # y DESC ties on y = 2, then x ASC breaks the tie in favour of x = 2
    by_y_then_x = DBSelect('a').order('y', 'DESC').order('x')
    self.assertEqual(
        flat(by_y_then_x),
        'SELECT "a".* FROM "a" ORDER BY "y" DESC, "x" ASC'
    )
    self.assertEqual(by_y_then_x.query().fetchone()['x'], 2)
def _update_version(cls, module, version):
    """Persist ``version`` as the installed version of ``module``.

    Updates the module's row when it is already tracked in ``cls._state``,
    inserts a new row otherwise, and finally force-reloads the cached state.
    """
    cls._initialize()
    if module not in cls._state:
        DBHelper().insert(cls._table, {'module': module, 'version': version})
    else:
        by_module = DBSelect(cls._table).where('module = ?', module)
        by_module.query_update({'version': version})
    cls._initialize(True)
def test_select_distinct(self):
    """Test toggling the DISTINCT directive on and off."""
    select = DBSelect('a')
    # (arguments passed to distinct(), expected normalised SQL)
    toggles = (
        ((), 'SELECT DISTINCT "a".* FROM "a"'),
        ((False, ), 'SELECT "a".* FROM "a"'),
    )
    for args, expected in toggles:
        select.distinct(*args)
        self.assertEqual(re.sub(r'\s+', ' ', str(select)), expected)
def get_filter_id(cls, filter_id):
    """Resolve a filter id or filter code to the filter's ``_id``.

    Results are memoised in ``cls._filter_ids``, keyed by both id and code.
    On a miss, the whole filter table is re-read into that cache once.

    :return: the matching ``_id``, or None when no filter matches.
    """
    if not hasattr(cls, '_filter_ids'):
        cls._filter_ids = {}
    lookup = cls._filter_ids
    if filter_id not in lookup:
        for row in DBSelect(cls._filter_table).query().fetchall():
            lookup[row['_id']] = row['_id']
            lookup[row['code']] = row['_id']
    return lookup.get(filter_id)
def test_modelset_len(self):
    """Test len(BaseModelSet) against a direct COUNT(*) query."""
    count_row = DBSelect(
        SampleModel._table, 'COUNT(*) as "c"'
    ).query().fetchone()
    total = count_row['c']

    self.assertEqual(len(SampleModel.all()), total)
    for limit in (1, 0):
        self.assertEqual(len(SampleModel.all().limit(limit)), limit)

    limited = SampleModel.all().limit(1)
    self.assertEqual(len(limited), 1)
    # total_size reports the full count despite limit(1), for either flag
    for flag in (False, True):
        self.assertEqual(limited.total_size(flag), total)
def test_select_columns(self):
    """Test adding columns to a DBSelect via columns()."""
    def flat(query):
        # collapse all whitespace runs so rendering layout is irrelevant
        return re.sub(r'\s+', ' ', str(query))

    select = DBSelect('a', {})
    select.columns({'c': 'b'})
    self.assertEqual(flat(select), 'SELECT "a"."b" AS "c" FROM "a"')

    # after joining "b", columns('*', 'b') adds "b".* to the column list
    select.left_join('b', 'a.b = b.a', {})
    select.columns('*', 'b')
    self.assertEqual(
        flat(select),
        'SELECT "a"."b" AS "c", "b".* FROM "a" LEFT JOIN "b" ON a.b = b.a'
    )
def test_select_from(self):
    """Test the SELECT <columns> FROM part of DBSelect."""
    select = DBSelect('a')
    sql = re.sub(r'\s+', ' ', select.render())
    self.assertEqual(sql, 'SELECT "a".* FROM "a"')

    # each set_from() prepends a new FROM table; earlier ones become joins
    select.set_from('b', ())
    sql = re.sub(r'\s+', ' ', select.render())
    self.assertEqual(sql, 'SELECT "a".* FROM "b" INNER JOIN "a"')

    select.set_from('c', ('d', 'e'))
    sql = re.sub(r'\s+', ' ', select.render())
    self.assertEqual(
        sql,
        'SELECT "a".*, "c"."d", "c"."e" ' +
        'FROM "c" INNER JOIN "b" INNER JOIN "a"'
    )

    count = self._helper.query(
        DBSelect('test_table', {'count': 'COUNT(*)'}).render()
    ).fetchone()
    # assertIs/assertGreaterEqual report the actual values on failure,
    # unlike assertTrue on a boolean expression
    self.assertIs(type(count['count']), int)
    self.assertGreaterEqual(count['count'], 0)
def test_select_clone(self):
    """Test that clone() yields an independent copy of a SELECT."""
    source = (DBSelect('a', {'c': 'col'})
              .left_join('b', 'a.c = b.d', {'d': 'col'})
              .where('a.col = ?', 1)
              .order('b.d', 'DESC')
              .limit(1, 2)
              .distinct(True))
    duplicate = source.clone()
    self.assertEqual(str(source), str(duplicate))

    # changing the clone must not affect the source
    duplicate.or_where('b.col = ?', 1)
    self.assertNotEqual(str(source), str(duplicate))

    # with WHERE dropped from both, they render identically again
    for query in (source, duplicate):
        query.unset(query.WHERE)
    self.assertEqual(str(source), str(duplicate))
def _initialize(cls, force=False):
    """Load module version info from the database into ``cls._state``.

    Creates the version table when ``describe_table`` reports nothing for
    it. The reload is skipped when ``cls._state`` is already truthy, unless
    ``force`` is set.
    """
    # NOTE(review): an empty ``_state`` dict is falsy, so while no module is
    # tracked every call re-describes the table and re-queries it. Confirm
    # whether that is intended before changing the guard (the class-level
    # initial value of ``_state`` is not visible here).
    if not force and cls._state:
        return
    if not DBHelper().describe_table(cls._table):
        DBHelper().query(""" CREATE TABLE %s ( "module" TEXT PRIMARY KEY, "version" INTEGER NOT NULL DEFAULT 0 ) """ % DBHelper.quote_identifier(cls._table))
    cls._state = {}
    rows = DBSelect(cls._table).query().fetchall()
    cls._state.update((row['module'], row['version']) for row in rows)
def test_select_limit(self):
    """Test the LIMIT and OFFSET parts of DBSelect."""
    def flat(query):
        # collapse all whitespace runs so rendering layout is irrelevant
        return re.sub(r'\s+', ' ', str(query))

    select = DBSelect('test_table')

    select.limit(10)
    self.assertEqual(
        flat(select),
        'SELECT "test_table".* FROM "test_table" LIMIT 10'
    )
    self.assertEqual(len(select.query().fetchall()), 10)

    # assuming test_table holds self._num_rows rows, only five remain
    # after this offset
    offset = self._num_rows - 5
    select.limit(10, offset)
    self.assertEqual(
        flat(select),
        'SELECT "test_table".* FROM "test_table" LIMIT %d, 10' % (offset,)
    )
    self.assertEqual(len(select.query().fetchall()), 5)
def filter_has_option(cls, filter_id, value):
    """Check whether ``value`` is a known option of the given filter.

    Options are memoised in ``cls._filter_options`` as
    ``{filter id: [option values]}``. On a cache miss the whole cache is
    rebuilt from ``cls._filter_option_table``.

    :param filter_id: filter id or code (resolved via ``get_filter_id``).
    :param value: option value to look for.
    :return: True if the option exists for that filter, False otherwise.
    """
    filter_id = cls.get_filter_id(filter_id)
    if not hasattr(cls, '_filter_options'):
        cls._filter_options = {}
    cached = cls._filter_options.get(filter_id)
    if cached is None or value not in cached:
        # Rebuild from scratch. Appending to the existing lists (as before)
        # re-added every option on each miss, so duplicates grew unbounded.
        options = {}
        for row in DBSelect(cls._filter_option_table).query().fetchall():
            options.setdefault(row['filter'], []).append(row['value'])
        cls._filter_options = options
    return value in cls._filter_options.get(filter_id, ())
def set_filter_values(cls, file_id, filter_id, values=None):
    """Replace the stored values of a filter for one file.

    Existing values for the (file, filter) pair are always deleted first, so
    passing ``values=None`` (or an empty list) just clears them.

    :param file_id: file id, or a FileModel instance (its ``id()`` is used).
    :param filter_id: filter id or code (resolved via ``get_filter_id``).
    :param values: None, a single value, or a list/tuple of values.
    """
    if isinstance(file_id, FileModel):
        file_id = file_id.id()
    filter_id = cls.get_filter_id(filter_id)
    DBSelect(cls._filter_value_table).where('file = ?', file_id).where(
        'filter = ?', filter_id).query_delete()
    if values is None:
        return
    # isinstance also accepts list/tuple subclasses, which the previous
    # exact-type check wrapped as a single value
    if not isinstance(values, (list, tuple)):
        values = (values, )
    rows = [
        {'file': file_id, 'filter': filter_id, 'value': value}
        for value in values
    ]
    if rows:
        # a single multi-row insert instead of one insert call per value
        DBHelper().insert(cls._filter_value_table, rows)
def _get_attribute_group(cls, group_id, create=False): """Retrieve an attribute group id""" if (cls._table in cls._group_index and group_id in cls._group_index[cls._table]): return cls._group_index[cls._table][group_id] group_desc = { '_id': group_id, 'code': str(group_id).replace(' ', '_'), 'label': str(group_id.replace('_', ' ')).capitalize() } cls._group_index[cls._table] = {} for field in ('_id', 'code', 'label'): select = DBSelect('attribute_group').where( '%s = ?' % field, group_desc[field]).where( 'type = ?', cls._table).limit(1) group = select.query().fetchone() if type(group) is dict: cls._group_index[cls._table][group_id] = group['_id'] return cls._group_index[cls._table][group_id] if not create: return None ids = DBHelper().insert('attribute_group', { 'code': group_desc['code'], 'label': group_desc['label'], 'type': cls._table }) if len(ids) > 0: cls._group_index[cls._table][group_id] = ids[0] return cls._group_index[cls._table][group_id] return None
def get_all_attributes(cls, group=None):
    """Return the attribute rows whose parent is this model's table.

    :param group: optional attribute group (id, code or label) used to
        restrict the result; resolved via ``_get_attribute_group``.
    :return: list of attribute rows (also stored in ``cls._cache``), or None
        when the class has no table or the group cannot be resolved.
    """
    parent = cls._table
    if not parent:
        return None

    cache_key = 'ATTRIBUTES_%s' % parent
    if group is not None:
        group = cls._get_attribute_group(group)
        if group is None:
            return None
        cache_key = '%s_%d' % (cache_key, group)

    cached = cls._cache.get(cache_key)
    if type(cached) is list:
        return cached

    # group -> group/attribute link -> attribute; only "a" columns selected
    select = DBSelect(('attribute_group', 'g'), ())
    select = select.inner_join(
        ('attribute_group_attribute', 'ga'), '"g"."_id" = "ga"."group"', ())
    select = select.inner_join(
        ('attribute', 'a'), '"ga"."attribute" = "a"."_id"')
    select = select.where('"a"."parent" = ?', parent)
    if group is not None:
        select.where('"g"."_id" = ?', group)

    rows = select.query().fetchall()
    if type(rows) is list:
        cls._cache.set(cache_key, rows)
    return rows
def get_filter_values_select(cls, filter_id):
    """Return a DBSelect over stored values of one filter.

    :param filter_id: filter id or code (resolved via ``get_filter_id``).
    :return: DBSelect on ``cls._filter_value_table`` aliased ``fv``,
        filtered to the resolved filter id.
    """
    resolved_id = cls.get_filter_id(filter_id)
    values_table = (cls._filter_value_table, 'fv')
    return DBSelect(values_table).where('filter = ?', resolved_id)