class MigrationsTestCase(unittest.TestCase):
    """Runs the sample migrations in order and checks the resulting table schema."""

    def setUp(self):
        # Drop the migration-history table so every test starts from scratch.
        self.database = Database('test-db')
        self.database.drop_table(MigrationHistory)

    def tearDown(self):
        # Fix: the original left 'test-db' behind after every run; drop it, matching
        # the teardown used by the other test cases in this file.
        self.database.drop_database()

    def tableExists(self, model_class):
        # EXISTS TABLE yields a single row whose `result` column is 1 when the table exists.
        query = "EXISTS TABLE $db.`%s`" % model_class.table_name()
        return next(self.database.select(query)).result == 1

    def getTableFields(self, model_class):
        # DESC lists the table's columns in order; return them as (name, type) pairs.
        query = "DESC `%s`.`%s`" % (self.database.db_name, model_class.table_name())
        return [(row.name, row.type) for row in self.database.select(query)]

    def test_migrations(self):
        # Migration 1 creates the table, 2 drops it, 3 re-creates it.
        self.database.migrate('tests.sample_migrations', 1)
        self.assertTrue(self.tableExists(Model1))
        self.database.migrate('tests.sample_migrations', 2)
        self.assertFalse(self.tableExists(Model1))
        self.database.migrate('tests.sample_migrations', 3)
        self.assertTrue(self.tableExists(Model1))
        # Fix: assertEquals is a deprecated alias removed in Python 3.12; use assertEqual.
        self.assertEqual(self.getTableFields(Model1), [('date', 'Date'), ('f1', 'Int32'), ('f2', 'String')])
        self.database.migrate('tests.sample_migrations', 4)
        self.assertEqual(self.getTableFields(Model2), [('date', 'Date'), ('f1', 'Int32'), ('f3', 'Float32'), ('f2', 'String'), ('f4', 'String')])
        self.database.migrate('tests.sample_migrations', 5)
        self.assertEqual(self.getTableFields(Model3), [('date', 'Date'), ('f1', 'Int64'), ('f3', 'Float64'), ('f4', 'String')])
# Example #2
class MigrationsTestCase(unittest.TestCase):
    """Runs the sample migrations in order and checks the schema after each step:
    table creation/deletion, simple-field changes, enum alterations, and
    materialized/alias columns."""

    def setUp(self):
        # Drop the migration-history table so every test starts from scratch.
        self.database = Database('test-db')
        self.database.drop_table(MigrationHistory)

    def tearDown(self):
        self.database.drop_database()

    def tableExists(self, model_class):
        # EXISTS TABLE yields a single row whose `result` column is 1 when the table exists.
        query = "EXISTS TABLE $db.`%s`" % model_class.table_name()
        return next(self.database.select(query)).result == 1

    def getTableFields(self, model_class):
        # DESC lists the table's columns in order; return them as (name, type) pairs.
        query = "DESC `%s`.`%s`" % (self.database.db_name,
                                    model_class.table_name())
        return [(row.name, row.type) for row in self.database.select(query)]

    def test_migrations(self):
        # Creation and deletion of table
        self.database.migrate('tests.sample_migrations', 1)
        self.assertTrue(self.tableExists(Model1))
        self.database.migrate('tests.sample_migrations', 2)
        self.assertFalse(self.tableExists(Model1))
        self.database.migrate('tests.sample_migrations', 3)
        self.assertTrue(self.tableExists(Model1))
        # Adding, removing and altering simple fields
        # Fix: assertEquals is a deprecated alias removed in Python 3.12; use assertEqual.
        self.assertEqual(self.getTableFields(Model1), [('date', 'Date'),
                                                       ('f1', 'Int32'),
                                                       ('f2', 'String')])
        self.database.migrate('tests.sample_migrations', 4)
        self.assertEqual(self.getTableFields(Model2), [('date', 'Date'),
                                                       ('f1', 'Int32'),
                                                       ('f3', 'Float32'),
                                                       ('f2', 'String'),
                                                       ('f4', 'String')])
        self.database.migrate('tests.sample_migrations', 5)
        self.assertEqual(self.getTableFields(Model3), [('date', 'Date'),
                                                       ('f1', 'Int64'),
                                                       ('f3', 'Float64'),
                                                       ('f4', 'String')])
        # Altering enum fields
        self.database.migrate('tests.sample_migrations', 6)
        self.assertTrue(self.tableExists(EnumModel1))
        self.assertEqual(self.getTableFields(EnumModel1),
                         [('date', 'Date'),
                          ('f1', "Enum8('dog' = 1, 'cat' = 2, 'cow' = 3)")])
        self.database.migrate('tests.sample_migrations', 7)
        self.assertTrue(self.tableExists(EnumModel1))
        self.assertEqual(
            self.getTableFields(EnumModel2),
            [('date', 'Date'),
             ('f1', "Enum16('dog' = 1, 'cat' = 2, 'horse' = 3, 'pig' = 4)")])
        # Materialized and alias columns
        self.database.migrate('tests.sample_migrations', 8)
        self.assertTrue(self.tableExists(MaterializedModel))
        self.assertEqual(self.getTableFields(MaterializedModel),
                         [('date_time', "DateTime"), ('date', 'Date')])
        self.database.migrate('tests.sample_migrations', 9)
        self.assertTrue(self.tableExists(AliasModel))
        self.assertEqual(self.getTableFields(AliasModel),
                         [('date', 'Date'), ('date_alias', "Date")])
class EnumFieldsTest(unittest.TestCase):
    """Tests enum-typed fields: round-trips through the database, value
    coercion, rejection of invalid values, defaults, and enum arrays."""

    def setUp(self):
        self.database = Database('test-db')
        for model in (ModelWithEnum, ModelWithEnumArray):
            self.database.create_table(model)

    def tearDown(self):
        self.database.drop_database()

    def _insert_two_fruits(self):
        # Shared fixture: one apple row, one orange row, ordered by date.
        self.database.insert([
            ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple),
            ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange)
        ])

    def test_insert_and_select(self):
        self._insert_two_fruits()
        rows = list(self.database.select('SELECT * from $table ORDER BY date_field', ModelWithEnum))
        self.assertEqual(len(rows), 2)
        self.assertEqual(rows[0].enum_field, Fruit.apple)
        self.assertEqual(rows[1].enum_field, Fruit.orange)

    def test_ad_hoc_model(self):
        # Selecting without a model class builds an ad-hoc model; compare by
        # name and value since the members belong to a generated enum type.
        self._insert_two_fruits()
        rows = list(self.database.select('SELECT * from $db.modelwithenum ORDER BY date_field'))
        self.assertEqual(len(rows), 2)
        for row, expected in zip(rows, (Fruit.apple, Fruit.orange)):
            self.assertEqual(row.enum_field.name, expected.name)
            self.assertEqual(row.enum_field.value, expected.value)

    def test_conversion(self):
        # ints, member names, and members themselves all coerce to the member.
        for raw, expected in ((3, Fruit.orange),
                              ('apple', Fruit.apple),
                              (Fruit.banana, Fruit.banana)):
            self.assertEqual(ModelWithEnum(enum_field=raw).enum_field, expected)

    def test_assignment_error(self):
        # Out-of-range ints, unknown names, None and floats are all rejected.
        for bad in (0, 17, 'pear', '', None, 99.9):
            with self.assertRaises(ValueError):
                ModelWithEnum(enum_field=bad)

    def test_default_value(self):
        self.assertEqual(ModelWithEnum().enum_field, Fruit.apple)

    def test_enum_array(self):
        row = ModelWithEnumArray(
            date_field='2016-08-30',
            enum_array=[Fruit.apple, Fruit.apple, Fruit.orange])
        self.database.insert([row])
        rows = list(self.database.select('SELECT * from $table ORDER BY date_field', ModelWithEnumArray))
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0].enum_array, row.enum_array)
class CustomFieldsTest(unittest.TestCase):
    """Tests custom field types (BooleanField, UUIDField) against a live database."""

    def setUp(self):
        # Fresh database per test; each test defines and creates its own table.
        self.database = Database('test-db')

    def tearDown(self):
        self.database.drop_database()

    def test_boolean_field(self):
        # Create a model
        class TestModel(Model):
            i = Int16Field()
            f = BooleanField()
            engine = Memory()
        self.database.create_table(TestModel)
        # Check valid values
        # 1/'1'/True coerce to True and 0/'0'/False to False; `i` records insert order.
        for index, value in enumerate([1, '1', True, 0, '0', False]):
            rec = TestModel(i=index, f=value)
            self.database.insert([rec])
        self.assertEqual([rec.f for rec in TestModel.objects_in(self.database).order_by('i')],
                          [True, True, True, False, False, False])
        # Check invalid values
        for value in [None, 'zzz', -5, 7]:
            with self.assertRaises(ValueError):
                TestModel(i=1, f=value)

    def test_uuid_field(self):
        # Create a model
        class TestModel(Model):
            i = Int16Field()
            f = UUIDField()
            engine = Memory()
        self.database.create_table(TestModel)
        # Check valid values (all values are the same UUID)
        # Every accepted input form below — braced/hex/URN strings, raw bytes,
        # a fields tuple, and a plain int — denotes the identical UUID.
        values = [
            '{12345678-1234-5678-1234-567812345678}',
            '12345678123456781234567812345678',
            'urn:uuid:12345678-1234-5678-1234-567812345678',
            '\x12\x34\x56\x78'*4,
            (0x12345678, 0x1234, 0x5678, 0x12, 0x34, 0x567812345678),
            0x12345678123456781234567812345678,
        ]
        for index, value in enumerate(values):
            rec = TestModel(i=index, f=value)
            self.database.insert([rec])
        for rec in TestModel.objects_in(self.database):
            self.assertEqual(rec.f, UUID(values[0]))
        # Check that ClickHouse encoding functions are supported
        for rec in self.database.select("SELECT i, UUIDNumToString(f) AS f FROM testmodel", TestModel):
            self.assertEqual(rec.f, UUID(values[0]))
        for rec in self.database.select("SELECT 1 as i, UUIDStringToNum('12345678-1234-5678-1234-567812345678') AS f", TestModel):
            self.assertEqual(rec.f, UUID(values[0]))
        # Check invalid values
        for value in [None, 'zzz', -1, '123']:
            with self.assertRaises(ValueError):
                TestModel(i=1, f=value)
class EnumFieldsTest(unittest.TestCase):
    """Tests enum fields: database round-trips, coercion rules, invalid-value
    rejection, defaults, and arrays of enum values."""

    def setUp(self):
        self.database = Database('test-db')
        self.database.create_table(ModelWithEnum)
        self.database.create_table(ModelWithEnumArray)

    def tearDown(self):
        self.database.drop_database()

    def test_insert_and_select(self):
        apple_row = ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple)
        orange_row = ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange)
        self.database.insert([apple_row, orange_row])
        selected = list(self.database.select('SELECT * from $table ORDER BY date_field', ModelWithEnum))
        self.assertEqual(len(selected), 2)
        self.assertEqual(selected[0].enum_field, Fruit.apple)
        self.assertEqual(selected[1].enum_field, Fruit.orange)

    def test_ad_hoc_model(self):
        # No model class on select: compare by member name/value since the
        # ad-hoc model carries its own generated enum type.
        apple_row = ModelWithEnum(date_field='2016-08-30', enum_field=Fruit.apple)
        orange_row = ModelWithEnum(date_field='2016-08-31', enum_field=Fruit.orange)
        self.database.insert([apple_row, orange_row])
        selected = list(self.database.select('SELECT * from $db.modelwithenum ORDER BY date_field'))
        self.assertEqual(len(selected), 2)
        self.assertEqual(selected[0].enum_field.name, Fruit.apple.name)
        self.assertEqual(selected[0].enum_field.value, Fruit.apple.value)
        self.assertEqual(selected[1].enum_field.name, Fruit.orange.name)
        self.assertEqual(selected[1].enum_field.value, Fruit.orange.value)

    def test_conversion(self):
        # ints, member names, and members all resolve to the Fruit member.
        for given, wanted in ((3, Fruit.orange),
                              ('apple', Fruit.apple),
                              (Fruit.banana, Fruit.banana)):
            self.assertEqual(ModelWithEnum(enum_field=given).enum_field, wanted)

    def test_assignment_error(self):
        for invalid in (0, 17, 'pear', '', None, 99.9):
            with self.assertRaises(ValueError):
                ModelWithEnum(enum_field=invalid)

    def test_default_value(self):
        self.assertEqual(ModelWithEnum().enum_field, Fruit.apple)

    def test_enum_array(self):
        row = ModelWithEnumArray(date_field='2016-08-30', enum_array=[Fruit.apple, Fruit.apple, Fruit.orange])
        self.database.insert([row])
        selected = list(self.database.select('SELECT * from $table ORDER BY date_field', ModelWithEnumArray))
        self.assertEqual(len(selected), 1)
        self.assertEqual(selected[0].enum_array, row.enum_array)
class DateFieldsTest(unittest.TestCase):
    """Tests Date/DateTime fields via an ad-hoc select with a computed column."""

    def setUp(self):
        # log_statements=True echoes each SQL statement for easier debugging.
        self.database = Database('test-db', log_statements=True)
        self.database.create_table(ModelWithDate)

    def tearDown(self):
        self.database.drop_database()

    def test_ad_hoc_model(self):
        self.database.insert([
            ModelWithDate(date_field='2016-08-30', datetime_field='2016-08-30 03:50:00'),
            ModelWithDate(date_field='2016-08-31', datetime_field='2016-08-31 01:30:00'),
        ])
        # toStartOfHour returns DateTime('Asia/Yekaterinburg') in my case, so I test it here to
        query = 'SELECT toStartOfHour(datetime_field) as hour_start, * from $db.modelwithdate ORDER BY date_field'
        rows = list(self.database.select(query))
        self.assertEqual(len(rows), 2)
        expected = [
            (datetime.date(2016, 8, 30),
             datetime.datetime(2016, 8, 30, 3, 50, 0, tzinfo=pytz.UTC),
             datetime.datetime(2016, 8, 30, 3, 0, 0, tzinfo=pytz.UTC)),
            (datetime.date(2016, 8, 31),
             datetime.datetime(2016, 8, 31, 1, 30, 0, tzinfo=pytz.UTC),
             datetime.datetime(2016, 8, 31, 1, 0, 0, tzinfo=pytz.UTC)),
        ]
        for row, (day, moment, hour_start) in zip(rows, expected):
            self.assertEqual(row.date_field, day)
            self.assertEqual(row.datetime_field, moment)
            self.assertEqual(row.hour_start, hour_start)
class DateFieldsTest(unittest.TestCase):
    """Tests Date/DateTime parsing on an ad-hoc model, including a computed column."""

    def setUp(self):
        self.database = Database('test-db')
        self.database.create_table(ModelWithDate)

    def tearDown(self):
        self.database.drop_database()

    def test_ad_hoc_model(self):
        sample_rows = [
            ModelWithDate(date_field='2016-08-30', datetime_field='2016-08-30 03:50:00'),
            ModelWithDate(date_field='2016-08-31', datetime_field='2016-08-31 01:30:00'),
        ]
        self.database.insert(sample_rows)

        # toStartOfHour returns DateTime('Asia/Yekaterinburg') in my case, so I test it here to
        query = 'SELECT toStartOfHour(datetime_field) as hour_start, * from $db.modelwithdate ORDER BY date_field'
        results = list(self.database.select(query))
        self.assertEqual(len(results), 2)
        first, second = results[0], results[1]
        self.assertEqual(first.date_field, datetime.date(2016, 8, 30))
        self.assertEqual(first.datetime_field, datetime.datetime(2016, 8, 30, 3, 50, 0, tzinfo=pytz.UTC))
        self.assertEqual(first.hour_start, datetime.datetime(2016, 8, 30, 3, 0, 0, tzinfo=pytz.UTC))
        self.assertEqual(second.date_field, datetime.date(2016, 8, 31))
        self.assertEqual(second.datetime_field, datetime.datetime(2016, 8, 31, 1, 30, 0, tzinfo=pytz.UTC))
        self.assertEqual(second.hour_start, datetime.datetime(2016, 8, 31, 1, 0, 0, tzinfo=pytz.UTC))
class ArrayFieldsTest(unittest.TestCase):
    """Tests array-typed fields: database round-trips, element coercion,
    invalid assignments, literal parsing, and inner-field validation."""

    def setUp(self):
        self.database = Database('test-db')
        self.database.create_table(ModelWithArrays)

    def tearDown(self):
        self.database.drop_database()

    def test_insert_and_select(self):
        # The string array deliberately contains characters that need escaping.
        row = ModelWithArrays(
            date_field='2016-08-30',
            arr_str=['goodbye,', 'cruel', 'world', 'special chars: ,"\\\'` \n\t\\[]'],
            arr_date=['2010-01-01'],
        )
        self.database.insert([row])
        query = 'SELECT * from $db.modelwitharrays ORDER BY date_field'
        # Exercise both the model-based and the ad-hoc (None) select paths.
        for model_cls in (ModelWithArrays, None):
            selected = list(self.database.select(query, model_cls))
            self.assertEqual(len(selected), 1)
            for attr in ('arr_str', 'arr_int', 'arr_date'):
                self.assertEqual(getattr(selected[0], attr), getattr(row, attr))

    def test_conversion(self):
        # Assigned values are coerced element-wise to each array's inner type.
        converted = ModelWithArrays(arr_int=('1', '2', '3'), arr_date=['2010-01-01'])
        self.assertEqual(converted.arr_str, [])
        self.assertEqual(converted.arr_int, [1, 2, 3])
        self.assertEqual(converted.arr_date, [date(2010, 1, 1)])

    def test_assignment_error(self):
        row = ModelWithArrays()
        for bad in (7, 'x', [date.today()], ['aaa'], [None]):
            with self.assertRaises(ValueError):
                row.arr_int = bad

    def test_parse_array(self):
        # Exercises the low-level array-literal parser directly.
        from infi.clickhouse_orm.utils import parse_array, unescape
        self.assertEqual(parse_array("[]"), [])
        self.assertEqual(parse_array("[1, 2, 395, -44]"), ["1", "2", "395", "-44"])
        self.assertEqual(parse_array("['big','mouse','','!']"), ["big", "mouse", "", "!"])
        self.assertEqual(parse_array(unescape("['\\r\\n\\0\\t\\b']")), ["\r\n\0\t\b"])
        # Unbalanced brackets and unterminated strings must raise.
        for malformed in ("", "[", "]", "[1, 2", "3, 4]", "['aaa', 'aaa]"):
            with self.assertRaises(ValueError):
                parse_array(malformed)

    def test_invalid_inner_field(self):
        # The inner field must be a usable field instance, per these rejected cases.
        for bad_inner in (DateField, None, "", ArrayField(Int32Field())):
            with self.assertRaises(AssertionError):
                ArrayField(bad_inner)
# Example #9
class ArrayFieldsTest(unittest.TestCase):
    """Tests array-typed fields against a live database (duplicate variant with
    statement logging enabled)."""

    def setUp(self):
        # log_statements=True echoes each SQL statement for easier debugging.
        self.database = Database('test-db', log_statements=True)
        self.database.create_table(ModelWithArrays)

    def tearDown(self):
        self.database.drop_database()

    def test_insert_and_select(self):
        # Round-trip a row whose string array includes characters that need escaping.
        instance = ModelWithArrays(
            date_field='2016-08-30',
            arr_str=['goodbye,', 'cruel', 'world', 'special chars: ,"\\\'` \n\t\\[]'],
            arr_date=['2010-01-01'],
        )
        self.database.insert([instance])
        query = 'SELECT * from $db.modelwitharrays ORDER BY date_field'
        # Check both the model-based and the ad-hoc (model_cls=None) select paths.
        for model_cls in (ModelWithArrays, None):
            results = list(self.database.select(query, model_cls))
            self.assertEqual(len(results), 1)
            self.assertEqual(results[0].arr_str, instance.arr_str)
            self.assertEqual(results[0].arr_int, instance.arr_int)
            self.assertEqual(results[0].arr_date, instance.arr_date)

    def test_conversion(self):
        # Assigned values are coerced element-wise to each array's inner type.
        instance = ModelWithArrays(
            arr_int=('1', '2', '3'),
            arr_date=['2010-01-01']
        )
        self.assertEqual(instance.arr_str, [])
        self.assertEqual(instance.arr_int, [1, 2, 3])
        self.assertEqual(instance.arr_date, [date(2010, 1, 1)])

    def test_assignment_error(self):
        # Non-list values and lists of the wrong element type are rejected.
        instance = ModelWithArrays()
        for value in (7, 'x', [date.today()], ['aaa'], [None]):
            with self.assertRaises(ValueError):
                instance.arr_int = value

    def test_parse_array(self):
        # Exercises the low-level array-literal parser directly.
        from infi.clickhouse_orm.utils import parse_array, unescape
        self.assertEqual(parse_array("[]"), [])
        self.assertEqual(parse_array("[1, 2, 395, -44]"), ["1", "2", "395", "-44"])
        self.assertEqual(parse_array("['big','mouse','','!']"), ["big", "mouse", "", "!"])
        self.assertEqual(parse_array(unescape("['\\r\\n\\0\\t\\b']")), ["\r\n\0\t\b"])
        # Unbalanced brackets and unterminated strings must raise.
        for s in ("",
                  "[",
                  "]",
                  "[1, 2",
                  "3, 4]",
                  "['aaa', 'aaa]"):
            with self.assertRaises(ValueError):
                parse_array(s)

    def test_invalid_inner_field(self):
        # The inner field must be a usable field instance, per these rejected cases.
        for x in (DateField, None, "", ArrayField(Int32Field())):
            with self.assertRaises(AssertionError):
                ArrayField(x)
class AliasFieldsTest(unittest.TestCase):
    """Tests ALIAS columns: values computed on SELECT rather than stored."""

    def setUp(self):
        # log_statements=True echoes each SQL statement for easier debugging.
        self.database = Database('test-db', log_statements=True)
        self.database.create_table(ModelWithAliasFields)

    def tearDown(self):
        self.database.drop_database()

    def test_insert_and_select(self):
        instance = ModelWithAliasFields(
            date_field='2016-08-30',
            int_field=-10,
            str_field='TEST'
        )
        self.database.insert([instance])
        # We can't select * from table, as it doesn't select materialized and alias fields
        query = 'SELECT date_field, int_field, str_field, alias_int, alias_date, alias_str, alias_func' \
                ' FROM $db.%s ORDER BY alias_date' % ModelWithAliasFields.table_name()
        # Check both the model-based and the ad-hoc (model_cls=None) select paths.
        for model_cls in (ModelWithAliasFields, None):
            results = list(self.database.select(query, model_cls))
            self.assertEqual(len(results), 1)
            self.assertEqual(results[0].date_field, instance.date_field)
            self.assertEqual(results[0].int_field, instance.int_field)
            self.assertEqual(results[0].str_field, instance.str_field)
            # Each alias mirrors its source column; alias_func evaluates to a
            # constant derived from the date (presumably YYYYMM — confirm in model).
            self.assertEqual(results[0].alias_int, instance.int_field)
            self.assertEqual(results[0].alias_str, instance.str_field)
            self.assertEqual(results[0].alias_date, instance.date_field)
            self.assertEqual(results[0].alias_func, 201608)

    def test_assignment_error(self):
        # I can't prevent assigning at all, in case db.select statements with model provided sets model fields.
        instance = ModelWithAliasFields()
        for value in ('x', [date.today()], ['aaa'], [None]):
            with self.assertRaises(ValueError):
                instance.alias_date = value

    def test_wrong_field(self):
        # A non-string alias target must be rejected at field construction.
        with self.assertRaises(AssertionError):
            StringField(alias=123)

    def test_duplicate_default(self):
        # alias cannot be combined with default or with materialized.
        with self.assertRaises(AssertionError):
            StringField(alias='str_field', default='with default')

        with self.assertRaises(AssertionError):
            StringField(alias='str_field', materialized='str_field')

    def test_default_value(self):
        instance = ModelWithAliasFields()
        self.assertEqual(instance.alias_str, NO_VALUE)
        # Check that NO_VALUE can be assigned to a field
        instance.str_field = NO_VALUE
        # Check that NO_VALUE can be assigned when creating a new instance
        # (instance2 is intentionally unused: constructing it without raising is the test)
        instance2 = ModelWithAliasFields(**instance.to_dict())
class FixedStringFieldsTest(unittest.TestCase):
    """Tests FixedString columns: empty, defaulted, short, and non-ASCII values."""

    def setUp(self):
        self.database = Database('test-db')
        self.database.create_table(FixedStringModel)

    def tearDown(self):
        self.database.drop_database()

    def _insert_sample_data(self):
        # Four rows: empty string, no value (field default), ASCII, and Hebrew text.
        rows = [
            FixedStringModel(date_field='2016-08-30', fstr_field=''),
            FixedStringModel(date_field='2016-08-30'),
            FixedStringModel(date_field='2016-08-31', fstr_field='foo'),
            FixedStringModel(date_field='2016-08-31', fstr_field=u'לילה'),
        ]
        self.database.insert(rows)

    def _assert_sample_data(self, results):
        expected = ['', 'ABCDEFGHIJK', 'foo', u'לילה']
        self.assertEqual(len(results), 4)
        for row, value in zip(results, expected):
            self.assertEqual(row.fstr_field, value)

    def test_insert_and_select(self):
        self._insert_sample_data()
        rows = list(self.database.select('SELECT * from $table ORDER BY date_field', FixedStringModel))
        self._assert_sample_data(rows)

    def test_ad_hoc_model(self):
        self._insert_sample_data()
        rows = list(self.database.select('SELECT * from $db.fixedstringmodel ORDER BY date_field'))
        self._assert_sample_data(rows)

    def test_assignment_error(self):
        # Non-strings, over-long values (in bytes) and None are all rejected.
        for bad in (17, 'this is too long', u'זה ארוך', None, 99.9):
            with self.assertRaises(ValueError):
                FixedStringModel(fstr_field=bad)
# Example #12
class FixedStringFieldsTest(unittest.TestCase):
    """Tests FixedString columns: empty, defaulted, short, and non-ASCII values."""

    def setUp(self):
        self.database = Database('test-db')
        self.database.create_table(FixedStringModel)

    def tearDown(self):
        self.database.drop_database()

    def _insert_sample_data(self):
        self.database.insert([
            FixedStringModel(date_field='2016-08-30', fstr_field=''),
            FixedStringModel(date_field='2016-08-30'),
            FixedStringModel(date_field='2016-08-31', fstr_field='foo'),
            FixedStringModel(date_field='2016-08-31', fstr_field=u'לילה')
        ])

    def _assert_sample_data(self, results):
        # Fix: assertEquals is a deprecated alias removed in Python 3.12; use assertEqual.
        self.assertEqual(len(results), 4)
        self.assertEqual(results[0].fstr_field, '')
        # Row 2 was inserted without fstr_field — presumably the field's default value.
        self.assertEqual(results[1].fstr_field, 'ABCDEFGHIJK')
        self.assertEqual(results[2].fstr_field, 'foo')
        self.assertEqual(results[3].fstr_field, u'לילה')

    def test_insert_and_select(self):
        self._insert_sample_data()
        query = 'SELECT * from $table ORDER BY date_field'
        results = list(self.database.select(query, FixedStringModel))
        self._assert_sample_data(results)

    def test_ad_hoc_model(self):
        self._insert_sample_data()
        query = 'SELECT * from $db.fixedstringmodel ORDER BY date_field'
        results = list(self.database.select(query))
        self._assert_sample_data(results)

    def test_assignment_error(self):
        # Non-strings, over-long values and None are all rejected.
        for value in (17, 'this is too long', u'זה ארוך', None, 99.9):
            with self.assertRaises(ValueError):
                FixedStringModel(fstr_field=value)
# Example #13
class MaterializedFieldsTest(unittest.TestCase):
    """Tests MATERIALIZED columns: computed and stored on insert, excluded from SELECT *."""

    def setUp(self):
        self.database = Database('test-db')
        self.database.create_table(ModelWithMaterializedFields)

    def tearDown(self):
        self.database.drop_database()

    def test_insert_and_select(self):
        instance = ModelWithMaterializedFields(
            date_time_field='2016-08-30 11:00:00',
            int_field=-10,
            str_field='TEST')
        self.database.insert([instance])
        # We can't select * from table, as it doesn't select materialized and alias fields
        query = 'SELECT date_time_field, int_field, str_field, mat_int, mat_date, mat_str' \
                ' FROM $db.%s ORDER BY mat_date' % ModelWithMaterializedFields.table_name()
        # Check both the model-based and the ad-hoc (model_cls=None) select paths.
        for model_cls in (ModelWithMaterializedFields, None):
            results = list(self.database.select(query, model_cls))
            self.assertEqual(len(results), 1)
            self.assertEqual(results[0].date_time_field,
                             instance.date_time_field)
            self.assertEqual(results[0].int_field, instance.int_field)
            self.assertEqual(results[0].str_field, instance.str_field)
            # The materialized columns are derived per these assertions:
            # abs() of the int, lower() of the string, date part of the datetime.
            self.assertEqual(results[0].mat_int, abs(instance.int_field))
            self.assertEqual(results[0].mat_str, instance.str_field.lower())
            self.assertEqual(results[0].mat_date,
                             instance.date_time_field.date())

    def test_assignment_error(self):
        # I can't prevent assigning at all, in case db.select statements with model provided sets model fields.
        instance = ModelWithMaterializedFields()
        for value in ('x', [date.today()], ['aaa'], [None]):
            with self.assertRaises(ValueError):
                instance.mat_date = value

    def test_wrong_field(self):
        # A non-string materialized expression must be rejected at field construction.
        with self.assertRaises(AssertionError):
            StringField(materialized=123)

    def test_duplicate_default(self):
        # materialized cannot be combined with default or with alias.
        with self.assertRaises(AssertionError):
            StringField(materialized='str_field', default='with default')

        with self.assertRaises(AssertionError):
            StringField(materialized='str_field', alias='str_field')
class MaterializedFieldsTest(unittest.TestCase):
    """Tests alias-field behavior via ModelWithAliasFields.

    NOTE(review): despite its name this class exercises ALIAS fields and
    re-uses the class name of the materialized-fields test above — if both
    definitions end up in one module, this one shadows the earlier one and
    its tests never run. Consider renaming; left unchanged to preserve the
    public interface.
    """

    def setUp(self):
        self.database = Database('test-db')
        self.database.create_table(ModelWithAliasFields)

    def tearDown(self):
        self.database.drop_database()

    def test_insert_and_select(self):
        instance = ModelWithAliasFields(
            date_field='2016-08-30',
            int_field=-10,
            str_field='TEST'
        )
        self.database.insert([instance])
        # We can't select * from table, as it doesn't select materialized and alias fields
        query = 'SELECT date_field, int_field, str_field, alias_int, alias_date, alias_str' \
                ' FROM $db.%s ORDER BY alias_date' % ModelWithAliasFields.table_name()
        # Check both the model-based and the ad-hoc (model_cls=None) select paths.
        for model_cls in (ModelWithAliasFields, None):
            results = list(self.database.select(query, model_cls))
            self.assertEqual(len(results), 1)
            self.assertEqual(results[0].date_field, instance.date_field)
            self.assertEqual(results[0].int_field, instance.int_field)
            self.assertEqual(results[0].str_field, instance.str_field)
            # Each alias column mirrors its source column.
            self.assertEqual(results[0].alias_int, instance.int_field)
            self.assertEqual(results[0].alias_str, instance.str_field)
            self.assertEqual(results[0].alias_date, instance.date_field)

    def test_assignment_error(self):
        # I can't prevent assigning at all, in case db.select statements with model provided sets model fields.
        instance = ModelWithAliasFields()
        for value in ('x', [date.today()], ['aaa'], [None]):
            with self.assertRaises(ValueError):
                instance.alias_date = value

    def test_wrong_field(self):
        # A non-string alias target must be rejected at field construction.
        with self.assertRaises(AssertionError):
            StringField(alias=123)

    def test_duplicate_default(self):
        # alias cannot be combined with default or with materialized.
        with self.assertRaises(AssertionError):
            StringField(alias='str_field', default='with default')

        with self.assertRaises(AssertionError):
            StringField(alias='str_field', materialized='str_field')
# Example #15
class ReadonlyTestCase(TestCaseWithData):
    """Verifies that a readonly connection can read but not write, and that
    operations on a readonly model are rejected even on a writable connection."""

    def _test_readonly_db(self, username):
        # Populate via the writable connection first, then switch to readonly.
        self._insert_and_check(self._sample_data(), len(data))
        orig_database = self.database
        try:
            self.database = Database(orig_database.db_name,
                                     username=username,
                                     readonly=True)
            with self.assertRaises(DatabaseException):
                self._insert_and_check(self._sample_data(), len(data))
            # Reads must still succeed on the readonly connection.
            # Fix: assertEquals is a deprecated alias removed in Python 3.12; use assertEqual.
            self.assertEqual(self.database.count(Person), 100)
            list(self.database.select('SELECT * from $table', Person))
            with self.assertRaises(DatabaseException):
                self.database.drop_table(Person)
            with self.assertRaises(DatabaseException):
                self.database.drop_database()
        except DatabaseException as e:
            # Skip (rather than fail) when the server has no such user configured.
            if 'Unknown user' in six.text_type(e):
                raise unittest.SkipTest('Database user "%s" is not defined' %
                                        username)
            else:
                raise
        finally:
            # Always restore the writable connection for subsequent tests.
            self.database = orig_database

    def test_readonly_db_with_default_user(self):
        self._test_readonly_db('default')

    def test_readonly_db_with_readonly_user(self):
        self._test_readonly_db('readonly')

    def test_insert_readonly(self):
        m = ReadOnlyModel(name='readonly')
        with self.assertRaises(DatabaseException):
            self.database.insert([m])

    def test_create_readonly_table(self):
        with self.assertRaises(DatabaseException):
            self.database.create_table(ReadOnlyModel)

    def test_drop_readonly_table(self):
        with self.assertRaises(DatabaseException):
            self.database.drop_table(ReadOnlyModel)
class DecimalFieldsTest(unittest.TestCase):
    """Tests for Decimal/Decimal32/Decimal64/Decimal128 fields: round-trips
    through a live ClickHouse table, validation, rounding, aggregation and
    precision/scale bounds."""

    def setUp(self):
        self.database = Database('test-db')
        # Decimal support was experimental in older ClickHouse versions and
        # had to be enabled with an explicit setting.
        self.database.add_setting('allow_experimental_decimal_type', 1)
        try:
            self.database.create_table(DecimalModel)
        except ServerError as e:
            if 'Unknown setting' in e.message:
                # This ClickHouse version does not support decimals yet
                raise unittest.SkipTest(e.message)
            else:
                raise

    def tearDown(self):
        self.database.drop_database()

    def _insert_sample_data(self):
        # Five rows; each (after the first) exercises one decimal column,
        # the remaining columns fall back to the model's defaults.
        self.database.insert([
            DecimalModel(date_field='2016-08-20'),
            DecimalModel(date_field='2016-08-21', dec=Decimal('1.234')),
            DecimalModel(date_field='2016-08-22', dec32=Decimal('12342.2345')),
            DecimalModel(date_field='2016-08-23', dec64=Decimal('12342.23456')),
            DecimalModel(date_field='2016-08-24', dec128=Decimal('-4545456612342.234567')),
        ])

    def _assert_sample_data(self, results):
        """Check the rows inserted by _insert_sample_data (ordered by date)."""
        self.assertEqual(len(results), 5)
        self.assertEqual(results[0].dec, Decimal(0))
        # NOTE(review): dec32 apparently defaults to 17 -- confirm against
        # DecimalModel's field definitions.
        self.assertEqual(results[0].dec32, Decimal(17))
        self.assertEqual(results[1].dec, Decimal('1.234'))
        self.assertEqual(results[2].dec32, Decimal('12342.2345'))
        self.assertEqual(results[3].dec64, Decimal('12342.23456'))
        self.assertEqual(results[4].dec128, Decimal('-4545456612342.234567'))

    def test_insert_and_select(self):
        self._insert_sample_data()
        query = 'SELECT * from $table ORDER BY date_field'
        results = list(self.database.select(query, DecimalModel))
        self._assert_sample_data(results)

    def test_ad_hoc_model(self):
        # Same query but without a model class: fields come back via an
        # ad-hoc model inferred from the result set.
        self._insert_sample_data()
        query = 'SELECT * from decimalmodel ORDER BY date_field'
        results = list(self.database.select(query))
        self._assert_sample_data(results)

    def test_rounding(self):
        # A value with more fractional digits than any column's scale is
        # rounded on insert.
        d = Decimal('11111.2340000000000000001')
        self.database.insert([DecimalModel(date_field='2016-08-20', dec=d, dec32=d, dec64=d, dec128=d)])
        m = DecimalModel.objects_in(self.database)[0]
        for val in (m.dec, m.dec32, m.dec64, m.dec128):
            self.assertEqual(val, Decimal('11111.234'))

    def test_assignment_ok(self):
        # Anything convertible to Decimal is accepted.
        for value in (True, False, 17, 3.14, '20.5', Decimal('20.5')):
            DecimalModel(dec=value)

    def test_assignment_error(self):
        # Non-numeric strings, None, NaN and infinities are rejected.
        for value in ('abc', u'זה ארוך', None, float('NaN'), Decimal('-Infinity')):
            with self.assertRaises(ValueError):
                DecimalModel(dec=value)

    def test_aggregation(self):
        self._insert_sample_data()
        result = DecimalModel.objects_in(self.database).aggregate(m='min(dec)', n='max(dec)')
        self.assertEqual(result[0].m, Decimal(0))
        self.assertEqual(result[0].n, Decimal('1.234'))

    def test_precision_and_scale(self):
        # Go over all valid combinations
        for precision in range(1, 39):
            for scale in range(0, precision + 1):
                f = DecimalField(precision, scale)
        # Some invalid combinations
        for precision, scale in [(0, 0), (-1, 7), (7, -1), (39, 5), (20, 21)]:
            with self.assertRaises(AssertionError):
                f = DecimalField(precision, scale)

    def test_min_max(self):
        """Boundary values validate, out-of-range values raise ValueError."""
        # In range
        f = DecimalField(3, 1)
        f.validate(f.to_python('99.9', None))
        f.validate(f.to_python('-99.9', None))
        # In range after rounding
        f.validate(f.to_python('99.94', None))
        f.validate(f.to_python('-99.94', None))
        # Out of range
        with self.assertRaises(ValueError):
            f.validate(f.to_python('99.99', None))
        with self.assertRaises(ValueError):
            f.validate(f.to_python('-99.99', None))
        # In range
        f = Decimal32Field(4)
        f.validate(f.to_python('99999.9999', None))
        f.validate(f.to_python('-99999.9999', None))
        # In range after rounding
        f.validate(f.to_python('99999.99994', None))
        f.validate(f.to_python('-99999.99994', None))
        # Out of range
        with self.assertRaises(ValueError):
            f.validate(f.to_python('100000', None))
        with self.assertRaises(ValueError):
            f.validate(f.to_python('-100000', None))
class NullableFieldsTest(unittest.TestCase):
    """Tests for NullableField wrappers: python/db-string conversion (where
    '\\N' represents NULL), isinstance delegation, and round-trips through
    a live ClickHouse table.

    Fix: ``_assert_sample_data`` asserted ``results[1].null_date`` twice and
    never checked ``results[1].null_int``, even though the inserted row sets
    ``null_int=None``; the duplicate now checks ``null_int``.
    """

    def setUp(self):
        self.database = Database('test-db')
        self.database.create_table(ModelWithNullable)

    def tearDown(self):
        self.database.drop_database()

    def test_nullable_datetime_field(self):
        f = NullableField(DateTimeField())
        epoch = datetime(1970, 1, 1, tzinfo=pytz.utc)
        # Valid values
        for value in (date(1970, 1, 1),
                      datetime(1970, 1, 1),
                      epoch,
                      epoch.astimezone(pytz.timezone('US/Eastern')),
                      epoch.astimezone(pytz.timezone('Asia/Jerusalem')),
                      '1970-01-01 00:00:00',
                      '1970-01-17 00:00:17',
                      '0000-00-00 00:00:00',
                      0,
                      '\\N'):
            dt = f.to_python(value, pytz.utc)
            if value == '\\N':
                self.assertIsNone(dt)
            else:
                self.assertEqual(dt.tzinfo, pytz.utc)
            # Verify that conversion to and from db string does not change value
            dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc)
            self.assertEqual(dt, dt2)
        # Invalid values
        for value in ('nope', '21/7/1999', 0.5):
            with self.assertRaises(ValueError):
                f.to_python(value, pytz.utc)

    def test_nullable_uint8_field(self):
        f = NullableField(UInt8Field())
        # Valid values
        for value in (17, '17', 17.0, '\\N'):
            python_value = f.to_python(value, pytz.utc)
            if value == '\\N':
                self.assertIsNone(python_value)
                self.assertEqual(value, f.to_db_string(python_value))
            else:
                self.assertEqual(python_value, 17)

        # Invalid values
        for value in ('nope', date.today()):
            with self.assertRaises(ValueError):
                f.to_python(value, pytz.utc)

    def test_nullable_string_field(self):
        f = NullableField(StringField())
        # Valid values; note '\\\\N' is an escaped literal, only '\\N' is NULL
        for value in ('\\\\N', 'N', 'some text', '\\N'):
            python_value = f.to_python(value, pytz.utc)
            if value == '\\N':
                self.assertIsNone(python_value)
                self.assertEqual(value, f.to_db_string(python_value))
            else:
                self.assertEqual(python_value, value)

    def test_isinstance(self):
        # isinstance checks pass through the wrapper to the inner field.
        for field in (StringField, UInt8Field, Float32Field, DateTimeField):
            f = NullableField(field())
            self.assertTrue(f.isinstance(field))
            self.assertTrue(f.isinstance(NullableField))
        for field in (Int8Field, Int16Field, Int32Field, Int64Field, UInt8Field, UInt16Field, UInt32Field, UInt64Field):
            f = NullableField(field())
            self.assertTrue(f.isinstance(BaseIntField))
        for field in (Float32Field, Float64Field):
            f = NullableField(field())
            self.assertTrue(f.isinstance(BaseFloatField))
        # Nested wrappers still delegate to the innermost field.
        f = NullableField(NullableField(UInt32Field()))
        self.assertTrue(f.isinstance(BaseIntField))
        self.assertTrue(f.isinstance(NullableField))
        self.assertFalse(f.isinstance(BaseFloatField))

    def _insert_sample_data(self):
        dt = date(1970, 1, 1)
        self.database.insert([
            ModelWithNullable(date_field='2016-08-30', null_str='', null_int=42, null_date=dt),
            ModelWithNullable(date_field='2016-08-30', null_str='nothing', null_int=None, null_date=None),
            ModelWithNullable(date_field='2016-08-31', null_str=None, null_int=42, null_date=dt),
            ModelWithNullable(date_field='2016-08-31', null_str=None, null_int=None, null_date=None)
        ])

    def _assert_sample_data(self, results):
        """Check the four rows inserted by _insert_sample_data (date order)."""
        dt = date(1970, 1, 1)
        self.assertEqual(len(results), 4)
        # NOTE(review): row 0 was inserted with null_str='' but comes back
        # None -- presumably empty strings round-trip as NULL; confirm.
        self.assertIsNone(results[0].null_str)
        self.assertEqual(results[0].null_int, 42)
        self.assertEqual(results[0].null_date, dt)
        self.assertEqual(results[1].null_str, 'nothing')
        # Fixed: this row sets null_int=None; the original checked null_date
        # twice and never checked null_int.
        self.assertIsNone(results[1].null_int)
        self.assertIsNone(results[1].null_date)
        self.assertIsNone(results[2].null_str)
        self.assertEqual(results[2].null_date, dt)
        self.assertEqual(results[2].null_int, 42)
        self.assertIsNone(results[3].null_int)
        self.assertIsNone(results[3].null_str)
        self.assertIsNone(results[3].null_date)

    def test_insert_and_select(self):
        self._insert_sample_data()
        query = 'SELECT * from $table ORDER BY date_field'
        results = list(self.database.select(query, ModelWithNullable))
        self._assert_sample_data(results)

    def test_ad_hoc_model(self):
        self._insert_sample_data()
        query = 'SELECT * from $db.modelwithnullable ORDER BY date_field'
        results = list(self.database.select(query))
        self._assert_sample_data(results)
class DatabaseTestCase(unittest.TestCase):
    """End-to-end Database tests: insert, count, select, pagination and
    escaping of special characters.

    Fixes: replaced the deprecated ``assertEquals`` alias (removed in
    Python 3.12) with ``assertEqual``, and replaced a list comprehension
    used purely for its side effect with a plain loop.
    """

    def setUp(self):
        self.database = Database('test-db')
        self.database.create_table(Person)

    def tearDown(self):
        self.database.drop_table(Person)
        self.database.drop_database()

    def _insert_and_check(self, data, count):
        """Insert *data* and verify the table then holds *count* rows."""
        self.database.insert(data)
        self.assertEqual(count, self.database.count(Person))

    def test_insert__generator(self):
        self._insert_and_check(self._sample_data(), len(data))

    def test_insert__list(self):
        self._insert_and_check(list(self._sample_data()), len(data))

    def test_insert__iterator(self):
        self._insert_and_check(iter(self._sample_data()), len(data))

    def test_insert__empty(self):
        self._insert_and_check([], 0)

    def test_count(self):
        self.database.insert(self._sample_data())
        self.assertEqual(self.database.count(Person), 100)
        self.assertEqual(self.database.count(Person, "first_name = 'Courtney'"), 2)
        self.assertEqual(self.database.count(Person, "birthday > '2000-01-01'"), 22)
        self.assertEqual(self.database.count(Person, "birthday < '1970-03-01'"), 0)

    def test_select(self):
        self._insert_and_check(self._sample_data(), len(data))
        query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
        results = list(self.database.select(query, Person))
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].last_name, 'Durham')
        self.assertEqual(results[0].height, 1.72)
        self.assertEqual(results[1].last_name, 'Scott')
        self.assertEqual(results[1].height, 1.70)

    def test_select_partial_fields(self):
        # Selecting a subset of columns leaves the rest at their defaults.
        self._insert_and_check(self._sample_data(), len(data))
        query = "SELECT first_name, last_name FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
        results = list(self.database.select(query, Person))
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].last_name, 'Durham')
        self.assertEqual(results[0].height, 0) # default value
        self.assertEqual(results[1].last_name, 'Scott')
        self.assertEqual(results[1].height, 0) # default value

    def test_select_ad_hoc_model(self):
        # Without a model class, rows come back as a generated AdHocModel.
        self._insert_and_check(self._sample_data(), len(data))
        query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
        results = list(self.database.select(query))
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].__class__.__name__, 'AdHocModel')
        self.assertEqual(results[0].last_name, 'Durham')
        self.assertEqual(results[0].height, 1.72)
        self.assertEqual(results[1].last_name, 'Scott')
        self.assertEqual(results[1].height, 1.70)

    def test_pagination(self):
        self._insert_and_check(self._sample_data(), len(data))
        # Try different page sizes
        for page_size in (1, 2, 7, 10, 30, 100, 150):
            # Iterate over pages and collect all instances
            page_num = 1
            instances = set()
            while True:
                page = self.database.paginate(Person, 'first_name, last_name', page_num, page_size)
                self.assertEqual(page.number_of_objects, len(data))
                self.assertGreater(page.pages_total, 0)
                for obj in page.objects:
                    instances.add(obj.to_tsv())
                if page.pages_total == page_num:
                    break
                page_num += 1
            # Verify that all instances were returned
            self.assertEqual(len(instances), len(data))

    def test_pagination_last_page(self):
        self._insert_and_check(self._sample_data(), len(data))
        # Try different page sizes
        for page_size in (1, 2, 7, 10, 30, 100, 150):
            # Ask for the last page in two different ways and verify equality
            page_a = self.database.paginate(Person, 'first_name, last_name', -1, page_size)
            page_b = self.database.paginate(Person, 'first_name, last_name', page_a.pages_total, page_size)
            self.assertEqual(page_a[1:], page_b[1:])
            self.assertEqual([obj.to_tsv() for obj in page_a.objects],
                             [obj.to_tsv() for obj in page_b.objects])

    def test_pagination_invalid_page(self):
        self._insert_and_check(self._sample_data(), len(data))
        for page_num in (0, -2, -100):
            with self.assertRaises(ValueError):
                self.database.paginate(Person, 'first_name, last_name', page_num, 100)

    def test_special_chars(self):
        # Quotes, backslashes, newlines, NUL etc. must round-trip intact.
        s = u'אבגד \\\'"`,.;éåäöšž\n\t\0\b\r'
        p = Person(first_name=s)
        self.database.insert([p])
        p = list(self.database.select("SELECT * from $table", Person))[0]
        self.assertEqual(p.first_name, s)

    def _sample_data(self):
        """Yield Person instances built from the module-level sample data."""
        for entry in data:
            yield Person(**entry)
Exemple #19
0

# Ad-hoc exploration script for the 'adp_ch' log database.
# NOTE(review): connects to a hard-coded internal host -- confirm the
# endpoint before running outside its original environment.
db = Database('adp_ch', db_url='http://s2.meta.vmc.loc:8123')
db.create_table(Log)
# db.drop_database()
# cat /Users/arturgspb/PycharmProjects/test/uploader/logs/collect_automator_logs__adp_ch.log_2016_08_26__1472226423_rotated.csv |  curl --data-binary @- 'http://s2.meta.vmc.loc:8123/?query=INSERT INTO adp_ch.log FORMAT CSV';

# cat /Users/arturgspb/PycharmProjects/test/uploader/logs/collect_automator_logs__adp_ch.log_2016_08_26__1472226423_rotated.csv | POST 'http://s2.meta.vmc.loc:8123/?query=INSERT INTO adp_ch FORMAT CSVWithNames';

#echo ' curl 'http://*****:*****@- -H 'Content-Encoding: gzip' 'http://s2.meta.vmc.loc:8123/?query=INSERT%20INTO%20adp_ch.log%20FORMAT%20CSVWithNames';

# meta = MetaApp()
# meta.log.info('Start')
#
# for idx in xrange(1000):
#     ops = []
#     meta.log.info('Do insert')
#     for idx in xrange(100000):
#         dan = Person(first_name='Dan ' + str(random.randint(1, 1000000)), last_name='Schwartz ' + str(uuid.uuid1()), height=random.randint(1, 1000000))
#         ops.append(dan)
#     db.insert(ops)
# meta.log.info('Stop')
# # db.insert([dan, suzy])
# #
# # # SELECT
# Print the total row count; the ad-hoc model exposes it as the 'cnt' alias.
for person in db.select("SELECT COUNT(*) as cnt FROM adp_ch.log"):
    print(person.__dict__)
#
#
# # params = {
from infi.clickhouse_orm import models, fields, engines
from infi.clickhouse_orm.database import Database


class Test(models.Model):
    """Minimal ClickHouse ORM model: an Int64 id plus four string columns."""

    id = fields.Int64Field()
    a = fields.StringField()
    b = fields.StringField()
    c = fields.StringField()
    d = fields.StringField()

    # NOTE(review): MergeTree's first argument is conventionally a Date
    # column name; here it is the Int64 'id' -- confirm this is intentional.
    engine = engines.MergeTree('id', ('a', 'b', 'c', 'd'))

# Demo script: connect, create the Test table, insert five rows and read
# them back.  Fix: this fragment used Python-2-only syntax (`xrange` and a
# print *statement*), which is a SyntaxError under Python 3 -- the dialect
# the rest of this file targets (it uses print() calls elsewhere).
db_url = 'http://web1:8123'
db_name = 'csv_parser_db'
db_username = '******'  # credentials redacted in this sample
db_password = '******'
db = Database(db_name=db_name, db_url=db_url, username=db_username, password=db_password)
db.create_table(Test)

# Insert some data
db.insert([
    Test(id=i, a=str(i), b=str(i), c=str(i), d=str(i)) for i in range(10, 15)
])

# Read data
for row in db.select("SELECT * FROM {}.test".format(db_name), model_class=Test):
    print(row.id, row.a, row.b, row.c, row.d)
class MigrationsTestCase(unittest.TestCase):
    """Runs the sample migrations in sequence and verifies the resulting
    table existence, schemas and data after each step.

    The steps are strictly order-dependent: each ``migrate`` call applies
    all migrations up to the given number.
    """

    def setUp(self):
        self.database = Database('test-db')
        # Drop the bookkeeping table so every test starts from migration 0.
        self.database.drop_table(MigrationHistory)

    def tearDown(self):
        self.database.drop_database()

    def tableExists(self, model_class):
        """Return True if the model's table exists in the test database."""
        query = "EXISTS TABLE $db.`%s`" % model_class.table_name()
        return next(self.database.select(query)).result == 1

    def getTableFields(self, model_class):
        """Return the table's schema as a list of (column_name, type) pairs."""
        query = "DESC `%s`.`%s`" % (self.database.db_name, model_class.table_name())
        return [(row.name, row.type) for row in self.database.select(query)]

    def test_migrations(self):
        # Creation and deletion of table
        self.database.migrate('tests.sample_migrations', 1)
        self.assertTrue(self.tableExists(Model1))
        self.database.migrate('tests.sample_migrations', 2)
        self.assertFalse(self.tableExists(Model1))
        self.database.migrate('tests.sample_migrations', 3)
        self.assertTrue(self.tableExists(Model1))
        # Adding, removing and altering simple fields
        self.assertEqual(self.getTableFields(Model1), [('date', 'Date'), ('f1', 'Int32'), ('f2', 'String')])
        self.database.migrate('tests.sample_migrations', 4)
        self.assertEqual(self.getTableFields(Model2), [('date', 'Date'), ('f1', 'Int32'), ('f3', 'Float32'), ('f2', 'String'), ('f4', 'String'), ('f5', 'Array(UInt64)')])
        self.database.migrate('tests.sample_migrations', 5)
        self.assertEqual(self.getTableFields(Model3), [('date', 'Date'), ('f1', 'Int64'), ('f3', 'Float64'), ('f4', 'String')])
        # Altering enum fields
        self.database.migrate('tests.sample_migrations', 6)
        self.assertTrue(self.tableExists(EnumModel1))
        self.assertEqual(self.getTableFields(EnumModel1),
                          [('date', 'Date'), ('f1', "Enum8('dog' = 1, 'cat' = 2, 'cow' = 3)")])
        self.database.migrate('tests.sample_migrations', 7)
        self.assertTrue(self.tableExists(EnumModel1))
        self.assertEqual(self.getTableFields(EnumModel2),
                          [('date', 'Date'), ('f1', "Enum16('dog' = 1, 'cat' = 2, 'horse' = 3, 'pig' = 4)")])
        # Materialized fields and alias fields
        self.database.migrate('tests.sample_migrations', 8)
        self.assertTrue(self.tableExists(MaterializedModel))
        self.assertEqual(self.getTableFields(MaterializedModel),
                          [('date_time', "DateTime"), ('date', 'Date')])
        self.database.migrate('tests.sample_migrations', 9)
        self.assertTrue(self.tableExists(AliasModel))
        self.assertEqual(self.getTableFields(AliasModel),
                          [('date', 'Date'), ('date_alias', "Date")])
        # Buffer models creation and alteration
        self.database.migrate('tests.sample_migrations', 10)
        self.assertTrue(self.tableExists(Model4))
        self.assertTrue(self.tableExists(Model4Buffer))
        self.assertEqual(self.getTableFields(Model4), [('date', 'Date'), ('f1', 'Int32'), ('f2', 'String')])
        self.assertEqual(self.getTableFields(Model4Buffer), [('date', 'Date'), ('f1', 'Int32'), ('f2', 'String')])
        self.database.migrate('tests.sample_migrations', 11)
        self.assertEqual(self.getTableFields(Model4), [('date', 'Date'), ('f3', 'DateTime'), ('f2', 'String')])
        self.assertEqual(self.getTableFields(Model4Buffer), [('date', 'Date'), ('f3', 'DateTime'), ('f2', 'String')])

        # Data-inserting migrations (RunPython / RunSQL style steps).
        self.database.migrate('tests.sample_migrations', 12)
        self.assertEqual(self.database.count(Model3), 3)
        data = [item.f1 for item in self.database.select('SELECT f1 FROM $table ORDER BY f1', model_class=Model3)]
        self.assertListEqual(data, [1, 2, 3])

        self.database.migrate('tests.sample_migrations', 13)
        self.assertEqual(self.database.count(Model3), 4)
        data = [item.f1 for item in self.database.select('SELECT f1 FROM $table ORDER BY f1', model_class=Model3)]
        self.assertListEqual(data, [1, 2, 3, 4])

        # Materialized/alias columns combined with regular ones.
        self.database.migrate('tests.sample_migrations', 14)
        self.assertTrue(self.tableExists(MaterializedModel1))
        self.assertEqual(self.getTableFields(MaterializedModel1),
                          [('date_time', 'DateTime'), ('int_field', 'Int8'), ('date', 'Date'), ('int_field_plus_one', 'Int8')])
        self.assertTrue(self.tableExists(AliasModel1))
        self.assertEqual(self.getTableFields(AliasModel1),
                          [('date', 'Date'), ('int_field', 'Int8'), ('date_alias', 'Date'), ('int_field_plus_one', 'Int8')])
class Patterning(object):
    def __init__(self):
        # Top-level CLI parser; individual arguments are added in setup().
        self.p = argparse.ArgumentParser(
            description="Patterning is a program that allows to perform de-novo protein design " \
                        "by searching for possible proteins or parts of protein which match a given 3D structure or a possible 1D secondary structure with possible binary patterning fingerprint")

    def setup(self):
        """Define all command-line arguments and parse ``sys.argv``.

        Returns True when arguments were parsed into ``self.args``; returns
        False (after printing help) when the script was invoked with no
        arguments at all.
        """
        self.db = None
        self.common_length = 0.0
        self.p.add_argument(
            "-w",
            "--workdir",
            default=".",
            help=
            "The working directory for this script to save and load temporary. By default it "
            "takes the current working directory. "
            "memory files required to function properly.")
        self.p.add_argument(
            "-q",
            "--pdbs",
            type=str,
            default=None,
            help="Local repository containing PDB Files to read from")
        # Database connection options.
        self.p.add_argument("-d",
                            "--database",
                            default="propensities",
                            help="Database Name to use")
        self.p.add_argument("-t",
                            "--host",
                            default="localhost",
                            help="Database Host to use. defaults to localhost")
        self.p.add_argument("-r",
                            "--port",
                            default=8123,
                            help="Database Default port to use")
        self.p.add_argument("-u",
                            "--username",
                            help="Username of the database to use")
        self.p.add_argument("-o",
                            "--password",
                            help="Database password to use to login")
        # Search input options: either a source PDB or a pattern + structure.
        self.p.add_argument(
            "-s",
            "--source",
            type=str,
            default=None,
            help=
            "Source PDB which contains the 3D structure of a protein or part of a protein (Fragment), that you want to search for similars in the database."
        )
        self.p.add_argument("-p",
                            "--pattern",
                            type=str,
                            default=None,
                            help="Binary Pattern of a fragment")
        self.p.add_argument(
            "-e",
            "--structure",
            type=str,
            default=None,
            help=
            "Possible Secondary structure of the fragment/protein, multiples of (H,B,E,G,I,T,S,-)."
        )
        self.p.add_argument(
            "-m",
            "--rmsd",
            type=float,
            default=-1,
            help=
            "Return matched proteins with this RMSD value only and exclude any others."
        )
        self.p.add_argument(
            "-l",
            "--limit",
            type=int,
            default=10,
            help="Total Number of results to include. Defaults to 10.")
        self.p.add_argument(
            "-x",
            "--start",
            type=int,
            default=-1,
            help="When searching by a whole protein containing a "
            "specific scaffold. This parameter denotes the "
            "location of the first residue of the fragment.")
        self.p.add_argument(
            "-y",
            "--end",
            type=int,
            default=-1,
            help="When searching by a whole protein containing a "
            "specific scaffold. This parameter denotes the "
            "location of the last residue of the fragment.")
        self.p.add_argument(
            "-a",
            "--chain",
            type=str,
            default=None,
            help=
            "Chain Identifier if your start and end are relative to a particular chain."
        )
        self.p.add_argument('-c',
                            '--cutoff',
                            default=23.9,
                            type=float,
                            help="Cutoff value to use to mark a residue "
                            "Solvent-accessible surface area as polar or "
                            "hydrophobic, if the SASA of a residue equal "
                            "to or higher than this value it will be "
                            "considered polar otherwise it will be "
                            "marked as hydrophobic")
        # Fuzzy-matching options.
        self.p.add_argument("-f",
                            "--fuzzy",
                            nargs="?",
                            const=True,
                            default=False,
                            help="Perform Fuzzy Search. Allow "
                            "matching similar but not "
                            "identical results.")
        self.p.add_argument("-z",
                            "--fuzzylevel",
                            type=float,
                            default=0.90,
                            help="Include only results with equal or "
                            "higher value than the following "
                            "fuzziness level .Defaults to 0.90")
        self.p.add_argument(
            "-v",
            "--distance",
            type=int,
            default=1,
            help="Possible Lovenstein distance for fuzzy string search")
        self.p.add_argument(
            "-j",
            "--deletions",
            type=int,
            default=0,
            help=
            "Number of allowed string deletions when fuzzy string search is enabled. Defaults to Zero."
        )
        self.p.add_argument(
            "-n",
            "--insertions",
            type=int,
            default=0,
            help=
            "Number of allowed string insertions when fuzzy string search is enabled. Defaults to Zero."
        )
        self.p.add_argument(
            "-k",
            "--substitutions",
            type=int,
            default=0,
            help=
            "Number of allowed string substitutions when fuzzy string search is enabled. Defaults to Zero."
        )

        # No arguments at all: show usage instead of parsing.
        if len(sys.argv) <= 1:
            self.p.print_help()
            return False
        self.args = self.p.parse_args()
        return True

    def connect(self):
        """Open a ClickHouse connection using the parsed CLI arguments."""
        host = self.args.host or "localhost"
        port = self.args.port or 8123
        endpoint = "http://%s:%s" % (host, port)
        self.db = Database(db_name=self.args.database,
                           db_url=endpoint,
                           username=self.args.username,
                           password=self.args.password)
        return True

    def print_results(self, headers, rows):
        """Render *rows* as an ASCII table (one column per attribute named in
        *headers*) and print the total number of rows.

        Fixes: removed a ``limit`` value that was computed but never used
        (row limiting already happens in the SQL queries), and corrected the
        table title typo "Matche(s)" -> "Match(es)".
        """
        data = [headers] + [
            [getattr(row, x) for x in headers if hasattr(row, x)]
            for row in rows
        ]
        table = AsciiTable(data)
        table.title = "Possible Match(es)"
        table.inner_row_border = True
        print(table.table)
        output = 'Total : {0}'.format(len(data) - 1)
        print(output)

    def search_by_common_pattern(self, pattern):
        """Search the ``proteins`` table for rows whose ``enhanced`` sequence
        matches *pattern* (exactly, or fuzzily via ngramSearch when --fuzzy
        is set).  Returns a ``(rows, headers)`` tuple, or None when there is
        no database connection.
        """
        if not self.db:
            print("Couldn't connect to the database.")
            self.p.print_help()
            return
        if self.args.fuzzy:
            query = "select p.* , ngramSearch(enhanced,'{0}') as ngram from proteins as p where ngram > {2} order by ngram DESC limit {3}".format(
                pattern, pattern, self.args.fuzzylevel, self.args.limit)
            score_column = "ngram"
        else:
            query = "select p.* , position(enhanced,'{0}') as pos from proteins as p where position(enhanced,'{1}') > 0 limit {2}".format(
                pattern, pattern, self.args.limit)
            score_column = "pos"
        headers = ["protein_id", score_column]
        rows = list(self.db.select(query))
        return rows, headers

    def get_secondary_structure_details(self, name, pdb_file, aa_only=False):
        """Parse *pdb_file* and return (name, sequence, secondary structure,
        per-residue SASA, structure).

        Returns the DSSP secondary-structure string, absolute solvent
        accessibility per residue, and the amino-acid sequence built from
        the peptide chains.
        """
        parser = PDBParser()
        structure = parser.get_structure(name, pdb_file)
        # DSSP with the Wilke max-ASA reference set.
        dssp = DSSP(structure[0], pdb_file, acc_array="Wilke")
        ss = "".join([aa[2] for aa in dssp])
        # NOTE(review): assumes aa[1] is the residue letter and aa[3] the
        # relative accessibility, so residues[aa[1]] * aa[3] gives absolute
        # SASA -- confirm against the Biopython DSSP tuple layout.
        sasa = [residues[aa[1]] * aa[3] for aa in dssp]
        builder = PPBuilder()
        seq = ""
        for chain in builder.build_peptides(structure, aa_only=aa_only):
            seq += chain.get_sequence()
        return name, seq, ss, sasa, structure

    def get_enhanced(self, ss, pattern):
        """Combine a secondary-structure string with a binary exposure
        pattern into a single 16-letter alphabet.

        For each position, a pattern digit of 0 (buried) maps the structure
        classes H,B,E,G,I,T,S to I..O (anything else to P); a non-zero digit
        (exposed) maps them to A..G (anything else to H).
        """
        buried = {'H': 'I', 'B': 'J', 'E': 'K', 'G': 'L',
                  'I': 'M', 'T': 'N', 'S': 'O'}
        exposed = {'H': 'A', 'B': 'B', 'E': 'C', 'G': 'D',
                   'I': 'E', 'T': 'F', 'S': 'G'}
        encoded = []
        for position, letter in enumerate(ss):
            if int(pattern[position]) == 0:
                encoded.append(buried.get(letter, 'P'))
            else:
                encoded.append(exposed.get(letter, 'H'))
        return "".join(encoded)

    def start(self):
        """Entry point: parse arguments, connect to the database, then
        dispatch on the input mode.

        Searches by a source PDB file when --source is given; otherwise by
        an explicit binary pattern plus secondary structure; prints help
        when neither mode is fully specified.

        Fix: replaced the non-idiomatic ``not x is None`` with
        ``x is not None`` (PEP 8).
        """
        if not self.setup():
            return
        self.connect()

        if self.args.source is not None:
            self.load_source_pdb()
        elif self.args.pattern is not None and self.args.structure is not None:
            self.process_patterning()
        else:
            self.p.print_help()
            return

    def load_source_pdb(self):
        """Drive a search from a source PDB file: extract the fragment's
        sequence, secondary structure and solvent accessibility, build the
        enhanced search pattern, run the (optionally fuzzy) database search,
        score the hits and print them sorted by RMSD.
        """
        source_file = self.args.source
        base_name = os.path.basename(source_file)
        name, _ = os.path.splitext(base_name)
        _, seq, ss, sasa, structure = self.get_secondary_structure_details(
            name, source_file)
        # Two addressing modes: global 1-based --start/--end positions, or a
        # --chain identifier with chain-relative positions.
        if self.args.start != -1 and self.args.end != -1 and not self.args.chain:
            seq_start = self.args.start - 1
            fragment_length = self.args.end - self.args.start
            seq = seq[seq_start:self.args.end + 1]
            ss = ss[seq_start:self.args.end + 1]
            sasa = sasa[seq_start:self.args.end + 1]
        else:
            start = self.args.start
            end = self.args.end
            fragment_length = end - start
            chain_id = self.args.chain
            seq_start = get_sequence_position(structure, chain_id, start, end)
            if seq_start == -1:
                self.p.error(
                    "Unable to get the sequence position from the chain identifier , start position and end position"
                )
                self.p.print_help()
                return
            seq = seq[seq_start:seq_start + fragment_length + 1]
            ss = ss[seq_start:seq_start + fragment_length + 1]
            sasa = sasa[seq_start:seq_start + fragment_length + 1]
        # Binary accessibility pattern: thresholded SASA values, unless an
        # explicit --pattern string of 0/1 characters was supplied.
        asa = [1 if a >= self.args.cutoff else 0
               for a in sasa] if self.args.pattern is None else [
                   int(x) for x in self.args.pattern
               ]
        common_sequence = self.get_enhanced(ss, asa)
        self.common_length = len(common_sequence)
        found_rows, incoming_headers = self.search_by_common_pattern(
            common_sequence)
        incoming_headers_set = set(incoming_headers)
        if len(found_rows) <= 0:
            print("No Records found.")
            return
        # Fuzzy mode: re-match each candidate with approximate matching,
        # cloning a row per extra match and dropping rows with no match.
        if self.args.fuzzy and len(found_rows) > 0:
            new_matches = []
            deleting = []
            for row in found_rows:
                matches = find_near_matches(
                    common_sequence,
                    row.enhanced,
                    max_l_dist=self.args.distance,
                    max_deletions=self.args.deletions,
                    max_insertions=self.args.insertions,
                    max_substitutions=self.args.substitutions)
                if len(matches) > 0:
                    repeats = 1
                    for match in matches:
                        if repeats <= 1:
                            # First match updates the original row in place.
                            setattr(row, "pos", match.start)
                            setattr(row, "end_pos", match.end)
                        else:
                            # Additional matches become cloned model rows.
                            keys = row.to_dict()
                            ngram = keys['ngram']
                            del keys['ngram']
                            new_model = ProteinModel(**keys)
                            setattr(new_model, "pos", match.start)
                            setattr(new_model, "end_pos", match.end)
                            setattr(new_model, "ngram", ngram)
                            new_matches.append(new_model)
                        repeats += 1
                else:
                    deleting.append(row)
            found_rows.extend(new_matches)
            for todelete in deleting:
                found_rows.remove(todelete)
        print("Calculating Elenated Score values. Please Wait....")
        deviated_rows = self.calculate_elenated_topology_score(
            seq, found_rows, seq_start, fragment_length)
        # Keep only rows with a valid RMSD, best (smallest) first.
        deviated_rows = sorted([x for x in deviated_rows if x.rmsd > -1],
                               key=lambda x: x.rmsd,
                               reverse=False)
        proper_headers_set = [
            "protein_id", "pos", "chain", "chain_pos", "end_pos",
            "chain_length", "deviation", "rmsd"
        ]
        if self.args.fuzzy:
            proper_headers_set += ['ngram']
        self.print_results(headers=proper_headers_set, rows=deviated_rows)

    def calculate_elenated_topology_score(self, seq, rows, seq_start,
                                          fragment_length):
        """Annotate each matched row with chain info, the elenated topology
        deviation and the RMSD against the source fragment.

        :param seq: source fragment sequence (used for its length only)
        :param rows: matched result rows; mutated in place
        :param seq_start: offset of the fragment in the source sequence
        :param fragment_length: number of residues in the fragment
        :return: the (possibly partially annotated) ``rows`` list
        """
        try:
            source_structure = self.__get_structure__(self.args.source)
            source_residues = [res for res in source_structure.get_residues()]
            for row in rows:
                setattr(row, "protein_id", str(row.protein_id).upper())
                position = row.pos
                try:
                    target_file = self.get_target_file(row.protein_id)
                except Exception:
                    # Target structure unavailable: mark this row with
                    # sentinel values and move on to the next one.
                    setattr(row, 'rmsd', -1)
                    setattr(row, 'chain_length', -1)
                    setattr(row, 'chain_pos', -1)
                    setattr(row, 'deviation', -1)
                    setattr(row, 'chain', "N/A")
                    continue
                target_structure = self.__get_structure__(target_file)
                target_residues = [
                    res for res in target_structure.get_residues()
                ]
                start_offset_residue = target_residues[position]
                chain, chain_position = get_chain_position(
                    target_structure, position)
                setattr(row, "chain", "{0}".format(chain))
                chain_length = self.get_chain_length(
                    target_structure, start_offset_residue.full_id[2])
                setattr(row, "chain_length", chain_length)
                setattr(row, "chain_pos", chain_position)
                setattr(row, "end_pos", chain_position + fragment_length)
                current_deviation = self.__get_elenated_topology(
                    source_residues[seq_start:(seq_start + fragment_length) +
                                    1], target_residues[position - 1:position +
                                                        fragment_length + 1],
                    len(seq))
                setattr(row, "deviation", current_deviation)
                self.calculate_rmsd_deviation(row,
                                              seq_start,
                                              fragment_length,
                                              aa_only=False)

        except Exception as e:
            # BUG FIX: ``e.message`` does not exist on Python 3 exceptions
            # and would itself raise AttributeError; print the exception.
            print(e)
            raise e

        finally:
            # NOTE: returning from ``finally`` swallows the re-raise above,
            # so callers always receive the rows (historical behavior kept).
            return rows

    def get_chain_length(self, target_structure, chain_name):
        """Return the residue count of the chain named *chain_name* in
        *target_structure*, or 0 when no chain has that id."""
        lengths = (len(candidate)
                   for candidate in target_structure.get_chains()
                   if candidate.id == chain_name)
        return next(lengths, 0)

    def get_chain_polypeptide(self, structure, pps, global_index):
        """Resolve a global residue index to its chain id, the matching
        EPolyPeptide (or None), and the position within that chain."""
        chain, position_in_chain = get_chain_position(structure, global_index)
        polypeptides = [EPolyPeptide(pp) for pp in pps]
        current_chain = next(
            (pp for pp in polypeptides if pp.chain_id == chain), None)
        return chain, current_chain, position_in_chain

    def calc_rmsd(self, source_atoms, target_atoms):
        """Superimpose two equally sized atom lists with the QCP algorithm
        and return the resulting RMSD, or -1 when the lists differ in
        length (superposition requires a one-to-one correspondence).
        """
        # Local imports kept: Biopython/numpy are only needed on this path.
        # (Removed an unused ``from math import sqrt``.)
        import numpy as np
        from Bio.PDB.QCPSuperimposer import QCPSuperimposer
        if len(source_atoms) != len(target_atoms):
            return -1
        # atom.coord holds (x, y, z); stack both atom lists into (N, 3).
        source_arr = np.array(
            [[atom.coord[0], atom.coord[1], atom.coord[2]]
             for atom in source_atoms])
        target_arr = np.array(
            [[atom.coord[0], atom.coord[1], atom.coord[2]]
             for atom in target_atoms])
        sup = QCPSuperimposer()
        sup.set(source_arr, target_arr)
        sup.run()
        return sup.get_rms()

    def calculate_rmsd_deviation(self,
                                 row,
                                 source_position,
                                 fragment_length,
                                 aa_only=False):
        """Compute the CA-backbone RMSD between the source fragment and the
        matched fragment of ``row``'s target protein, storing the rounded
        value on ``row.rmsd`` (-1 when no comparison is possible).
        """
        if self.args.source is None:
            # BUG FIX: previously fell through after setting the sentinel
            # and crashed parsing a None path; bail out instead.
            setattr(row, "rmsd", -1)
            return
        target_position = row.pos
        source_structure = self.__get_structure__(self.args.source)
        builder = PPBuilder()
        source_pps = [
            x
            for x in builder.build_peptides(source_structure, aa_only=aa_only)
        ]
        # Locate the source fragment inside its chain's polypeptide.
        _, source_chain, source_position_in_chain = self.get_chain_polypeptide(
            source_structure, source_pps, source_position)
        source_backbones = [
            atom['CA'] for atom in source_chain[source_position_in_chain -
                                                1:source_position_in_chain +
                                                fragment_length]
        ]
        target_file = self.get_target_file(row.protein_id)
        if target_file is None:
            setattr(row, "rmsd", -1)
            return
        target_structure = self.__get_structure__(target_file)
        target_pps = [
            x
            for x in builder.build_peptides(target_structure, aa_only=aa_only)
        ]
        _, target_chain, target_position_in_chain = self.get_chain_polypeptide(
            target_structure, target_pps, target_position)
        target_backbone = [
            atom['CA'] for atom in target_chain[target_position_in_chain -
                                                1:target_position_in_chain +
                                                fragment_length]
        ]
        # Backbones must align one-to-one for superposition.
        if len(source_backbones) != len(target_backbone):
            setattr(row, 'rmsd', -1)
            return
        RMSD = self.calc_rmsd(source_backbones, target_backbone)
        setattr(row, "rmsd", round(RMSD, 4))

    def calculate_RMSD(self,
                       row,
                       source_position,
                       fragment_length,
                       aa_only=False):
        """Legacy RMSD computation over global residue indices; stores the
        rounded result on ``row.rmsd`` (-1 when superposition is impossible).
        """
        if self.args.source is None:
            # BUG FIX: previously fell through after setting the sentinel
            # and crashed parsing a None path; bail out instead.
            setattr(row, "rmsd", -1)
            return
        target_position = row.pos
        source_structure = self.__get_structure__(self.args.source)
        builder = PPBuilder()
        fixed_residues = []
        for pp in builder.build_peptides(source_structure, aa_only=aa_only):
            fixed_residues += [x for x in pp]
        fixed = [atom['CA'] for atom in fixed_residues
                 ][source_position:source_position + fragment_length]
        target_file = self.get_target_file(row.protein_id)
        if target_file is None:
            setattr(row, "rmsd", -1)
            return
        target_structure = self.__get_structure__(target_file)
        moving_residues = []
        for pp in builder.build_peptides(target_structure, aa_only=aa_only):
            moving_residues += [x for x in pp]
        moving = [atom['CA'] for atom in moving_residues
                  ][target_position:target_position + fragment_length]
        # The two CA selections must match in size for set_atoms().
        if len(fixed) != len(moving):
            setattr(row, "rmsd", -1)
            return
        sup = Bio.PDB.Superimposer()
        sup.set_atoms(fixed, moving)
        sup.apply(target_structure[0].get_atoms())
        setattr(row, "rmsd", round(sup.rms, 4))

    def __get_structure__(self, file_path):
        """Parse *file_path* and return its structure, selecting the parser
        from the extension (mmCIF when it contains 'cif', PDB otherwise)."""
        base_name = os.path.basename(file_path)
        name, ext = os.path.splitext(base_name)
        parser = MMCIFParser() if 'cif' in ext else PDBParser()
        return parser.get_structure(name, file_path)

    def get_target_file(self, protein_id):
        """Locate (or download from RCSB) the structure file for
        *protein_id* and return its path.

        Search order: the --pdbs directory (preferring .pdb over .cif),
        then the work directory, then a fresh download into the work
        directory.
        """
        if self.args.pdbs is not None:
            which_file = os.path.join(self.args.pdbs, protein_id)
            if os.path.exists(which_file + ".pdb"):
                return which_file + ".pdb"
            elif os.path.exists(which_file + ".cif"):
                return which_file + ".cif"
            else:
                # Historical fallback: return the .pdb path even if missing
                # and let the caller's parse attempt fail.
                return which_file + ".pdb"
        output_file = os.path.join(self.args.workdir,
                                   "{0}.pdb".format(protein_id))
        if os.path.exists(output_file):
            return output_file
        print("Downloading File : {0}".format(protein_id))
        download_url = "https://files.rcsb.org/download/{0}.pdb".format(
            protein_id)
        response = urllib2.urlopen(download_url)
        try:
            with open(output_file, mode='w') as output:
                output.write(response.read())
        finally:
            # BUG FIX: the HTTP response was never closed.
            response.close()
        print("Downloaded.")
        return output_file

    def get_phi(self, previous, source_res):
        """Phi dihedral angle (degrees) of *source_res* given the preceding
        residue; returns 0.0 when any required backbone atom is missing."""
        try:
            vectors = (previous['C'].get_vector(),
                       source_res['N'].get_vector(),
                       source_res['CA'].get_vector(),
                       source_res['C'].get_vector())
            return degrees(calc_dihedral(*vectors))
        except Exception:
            return 0.0

    def get_psi(self, target_res, next_res):
        """Psi dihedral angle (degrees) of *target_res* given the following
        residue; returns 0.0 when any required backbone atom is missing."""
        try:
            vectors = (target_res['N'].get_vector(),
                       target_res['CA'].get_vector(),
                       target_res['C'].get_vector(),
                       next_res['N'].get_vector())
            return degrees(calc_dihedral(*vectors))
        except Exception:
            return 0.0

    def __get_elenated_topology(self, source_residues, target_residues,
                                seq_length):
        """
        This method will calculate the elenated topology mean square deviation
        :param source_residues: Fragment source residues
        :param target_residues: Target protein residues
        :param seq_length: fragment sequence length used to normalise the sum
        :return: Elenated T(msd) Value between these two proteins
        """
        # target residues should be longer than the source residues in most of cases
        if len(source_residues) > len(target_residues):
            return 0.0
        deviation = 0.0
        total = 0.0
        for index, res in enumerate(source_residues[:-1]):
            # Skip the first residue: phi needs a predecessor, and the slice
            # above already excludes the last one (psi needs a successor).
            if index == 0:
                continue
            source_res = source_residues[index]
            target_res = target_residues[index]
            # calculate Phi and Psi torsional angles for the current residues
            source_phi = self.get_phi(source_residues[index - 1], source_res)
            target_phi = self.get_phi(target_residues[index - 1], target_res)
            source_psi = self.get_psi(source_res, source_residues[index + 1])
            target_psi = self.get_psi(target_res, target_residues[index + 1])
            # atan2(sin(d), cos(d)) wraps an angle difference into (-pi, pi].
            # NOTE(review): get_phi/get_psi return DEGREES while sin/cos
            # expect radians, so this wrapping mixes units -- confirm the
            # intended units before relying on absolute deviation values.
            deviation += atan2(sin(source_phi - target_phi),
                               cos(source_phi - target_phi)) + atan2(
                                   sin(source_psi - target_psi),
                                   cos(source_psi - target_psi))
            total += 1
        if total == 0:
            return 0.0
        else:
            # Normalised by the fragment length, scaled to a percentage.
            return (abs(deviation) / float(seq_length)) * 100.0

    def process_patterning(self):
        """Run a search from a user-supplied secondary-structure string and
        binary accessibility pattern instead of a source PDB file."""
        ss = self.args.structure
        pattern = self.args.pattern
        if len(ss) == len(pattern):
            common_sequence = self.get_enhanced(ss, pattern)
            self.common_length = len(common_sequence)
            self.search_by_common_pattern(common_sequence)
        else:
            print(
                "Length of Both Secondary Structure and Binary Patterning should equal."
            )
            self.p.print_help()
class clickpandas:
    """Helper that runs templated SQL against a ClickHouse database and
    returns results as pandas DataFrames, optionally splitting the query's
    time range for parallel or out-of-core (dask) execution."""

    def __init__(self, config_path):
        """Load connection settings from the JSON file at *config_path*
        (expected keys: db_name, db_url, username, password) and connect."""
        self.config_path = config_path

        with open(os.path.expanduser(self.config_path), 'r') as fp:
            self.config = json.load(fp)

        self.db_name = self.config['db_name']
        self.db_url = self.config['db_url']
        self.username = self.config['username']
        self.password = self.config['password']
        # SQLAlchemy-style URI consumed by dask.dataframe.read_sql_table.
        # NOTE: embeds the plain-text password.
        self.uri = 'clickhouse://' + self.username + ':' + self.password + '@' + self.db_url.split(
            '//')[1] + '/' + self.db_name
        self.db = Database(self.db_name,
                           db_url=self.db_url,
                           username=self.username,
                           password=self.password)

        # Row attributes that belong to the ORM machinery, not to the data.
        self.technical_attr = {
            '__class__', '__delattr__', '__dict__', '__dir__', '__doc__',
            '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__',
            '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__',
            '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__',
            '__repr__', '__setattr__', '__sizeof__', '__str__',
            '__subclasshook__', '__weakref__', '_database', '_fields',
            '_writable_fields', 'create_table_sql', 'drop_table_sql', 'engine',
            'from_tsv', 'get_database', 'get_field', 'objects_in', 'readonly',
            'set_database', 'system', 'table_name', 'to_dict', 'to_tsv'
        }

        # Regexes for carving up a query: timestamp literals, the column
        # list between SELECT and FROM, and the table reference.
        self.dt_pattern = r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}'
        self.col_pattern = r'select(.*?)from'
        self.table_pattern = r'from (.*?) '

    def insert_rows(self, data_buffer, datamodel):
        """Create *datamodel*'s table if needed and insert *data_buffer*,
        printing (but not propagating) any insertion error."""
        self.db.create_table(datamodel)
        try:
            self.db.insert(data_buffer)
        except Exception:
            # BUG FIX: a bare ``except:`` also swallowed SystemExit and
            # KeyboardInterrupt; catch Exception and keep the report.
            traceback.print_exc()

    def execute_query(self, f, **kwargs):
        """Run the SQL produced by ``f(**kwargs)`` and return a DataFrame
        of all non-technical row attributes (columns named 'no_data' when
        the result set is empty)."""
        data_buffer = []

        for row in self.db.select(f(**kwargs)):
            quoted_attributes = [
                getattr(row, attr) for attr in dir(row)
                if attr not in self.technical_attr
            ]
            data_buffer.append(quoted_attributes)

        # ``row`` is only bound if the query yielded at least one record.
        if 'row' in dir():
            columns = [
                attr for attr in dir(row) if attr not in self.technical_attr
            ]
        else:
            columns = ['no_data']

        return pd.DataFrame(data_buffer, columns=columns)

    def proxy_execute(self, template, start, end):
        """Fill *template* with the [start, end) bounds and execute it;
        this is the picklable unit of work used by parallel_execute."""
        def fill_template(template, start, end):
            # Template placeholders are positional ``{}`` pairs.
            return template.format(start, end)

        # BUG FIX: previously called the non-existent ``self.execute``;
        # the query runner in this class is named ``execute_query``.
        result = self.execute_query(fill_template,
                                    template=template,
                                    start=start,
                                    end=end)

        return result

    def __get_time_splits(self, query, n_splits=2):
        """Yield (template, start, end) triples that partition the time
        range found in *query* into *n_splits* contiguous windows."""
        timestamps = [datetime.datetime.strptime(x.group(), "%Y-%m-%d %H:%M:%S") \
        for x in re.finditer(self.dt_pattern, query)]
        template = re.sub(self.dt_pattern, '{}', query)
        start = min(timestamps)
        diff = (max(timestamps) - start) / n_splits
        # Drop the sub-second remainder so window bounds stay at whole
        # seconds (matching the "%H:%M:%S" timestamp format).
        diff -= datetime.timedelta(microseconds=diff.microseconds)

        for i in range(1, n_splits + 1):

            yield (template, start + diff * (i - 1), start + diff * i)

    def parallel_execute(self, f, n_splits=2, **kwargs):
        """Split the query's time range into *n_splits* windows, run them
        concurrently in worker processes, and return the concatenation."""
        query = f(**kwargs)
        executor = ProcessPoolExecutor(n_splits)
        futures = [executor.submit(self.proxy_execute, date_range[0], date_range[1], date_range[2]) \
        for date_range in self.__get_time_splits(query, n_splits)]

        df = pd.DataFrame()
        for proc in as_completed(futures):
            df = pd.concat([df, proc.result()])

        return df

    def out_of_core_execute(self,
                            timestamp_column,
                            f,
                            blocksize=268435456,
                            **kwargs):
        """Lazily read the queried table as a dask DataFrame partitioned on
        *timestamp_column*, with roughly *blocksize* bytes per chunk."""
        query = f(**kwargs)
        # Single time split yields the overall [start, end] divisions.
        timestamps = [[str(date_range[1]),
                       str(date_range[2])]
                      for date_range in self.__get_time_splits(query, 1)][0]
        columns = re.findall(self.col_pattern, query)[0].replace(' ',
                                                                 '').split(',')
        table = re.findall(self.table_pattern, query)[0].split('.')[1]
        df = dd.read_sql_table(table=table,
                               divisions=timestamps,
                               uri=self.uri,
                               columns=columns,
                               index_col=timestamp_column,
                               bytes_per_chunk=blocksize)

        return df
class DatabaseTestCase(unittest.TestCase):
    """Integration tests for Database insert/count/select against a live
    ClickHouse server, using the ``test_db`` database.

    Note: ``assertEquals`` was replaced with ``assertEqual`` throughout;
    the alias is deprecated and removed in Python 3.12.
    """

    def setUp(self):
        self.database = Database('test_db')
        self.database.create_table(Person)

    def tearDown(self):
        self.database.drop_table(Person)
        self.database.drop_database()

    def _insert_and_check(self, data, count):
        """Insert *data* and assert the Person table now holds *count* rows."""
        self.database.insert(data)
        self.assertEqual(count, self.database.count(Person))

    def test_insert__generator(self):
        self._insert_and_check(self._sample_data(), len(data))

    def test_insert__list(self):
        self._insert_and_check(list(self._sample_data()), len(data))

    def test_insert__iterator(self):
        self._insert_and_check(iter(self._sample_data()), len(data))

    def test_insert__empty(self):
        self._insert_and_check([], 0)

    def test_count(self):
        self.database.insert(self._sample_data())
        self.assertEqual(self.database.count(Person), 100)
        self.assertEqual(self.database.count(Person, "first_name = 'Courtney'"), 2)
        self.assertEqual(self.database.count(Person, "birthday > '2000-01-01'"), 22)
        self.assertEqual(self.database.count(Person, "birthday < '1970-03-01'"), 0)

    def test_select(self):
        self._insert_and_check(self._sample_data(), len(data))
        query = "SELECT * FROM test_db.person WHERE first_name = 'Whitney' ORDER BY last_name"
        results = list(self.database.select(query, Person))
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].last_name, 'Durham')
        self.assertEqual(results[0].height, 1.72)
        self.assertEqual(results[1].last_name, 'Scott')
        self.assertEqual(results[1].height, 1.70)

    def test_select_partial_fields(self):
        self._insert_and_check(self._sample_data(), len(data))
        query = "SELECT first_name, last_name FROM test_db.person WHERE first_name = 'Whitney' ORDER BY last_name"
        results = list(self.database.select(query, Person))
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].last_name, 'Durham')
        self.assertEqual(results[0].height, 0) # default value
        self.assertEqual(results[1].last_name, 'Scott')
        self.assertEqual(results[1].height, 0) # default value

    def test_select_ad_hoc_model(self):
        self._insert_and_check(self._sample_data(), len(data))
        query = "SELECT * FROM test_db.person WHERE first_name = 'Whitney' ORDER BY last_name"
        results = list(self.database.select(query))
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].__class__.__name__, 'AdHocModel')
        self.assertEqual(results[0].last_name, 'Durham')
        self.assertEqual(results[0].height, 1.72)
        self.assertEqual(results[1].last_name, 'Scott')
        self.assertEqual(results[1].height, 1.70)

    def _sample_data(self):
        """Yield Person instances built from the module-level ``data``."""
        for entry in data:
            yield Person(**entry)
# Exemple #25 (scraper artifact kept as a comment so the file stays importable)
class ClickHouseDataWrapper(ForeignDataWrapper):
    """Multicorn foreign data wrapper that exposes a ClickHouse table to
    PostgreSQL, pushing quals and sort keys down to ClickHouse and
    estimating relation sizes from per-column statistics."""

    def __init__(self, options, columns):
        # options: FDW options (db_name, db_url, table_name);
        # columns: mapping of column name -> PostgreSQL column definition.
        super(ClickHouseDataWrapper, self).__init__(options, columns)
        # TODO add username, password, debug
        self.db_name    = options.get('db_name', 'default')
        self.db_url     = options.get('db_url', 'http://localhost:8123/')
        self.db         = Database(self.db_name, self.db_url)
        self.table_name = options['table_name']
        self.model      = self._build_model()
        self.column_stats = self._get_column_stats(columns)

    def _build_model(self):
        """Build an ad-hoc ORM model reflecting the ClickHouse table schema."""
        sql = "SELECT name, type FROM system.columns where database='%s' and table='%s'" % (self.db_name, self.table_name)
        cols = [(row.name, row.type) for row in self.db.select(sql)]
        return ModelBase.create_ad_hoc_model(cols, model_name=self.table_name)

    def can_sort(self, sortkeys):
        # All requested sort keys are accepted: sorting is pushed to ClickHouse.
        return sortkeys

    def get_rel_size(self, quals, columns):
        """Estimate (row count, total selected-column width) for the planner."""
        qs = self._build_query(quals, columns)
        total_size = sum(self.column_stats[c]['size'] for c in columns)
        ret = (qs.count(), total_size)
        return ret

    def get_path_keys(self):
        # Per-column average rows per distinct value, for access-path costing.
        return [((name,), stats['average_rows']) for name, stats in self.column_stats.items()]

    def execute(self, quals, columns, sortkeys=None):
        """Stream matching rows as dicts restricted to the requested columns."""
        qs = self._build_query(quals, columns, sortkeys)
        log_to_postgres(qs.as_sql())
        for instance in qs:
            yield instance.to_dict(field_names=columns)

    def explain(self, quals, columns, sortkeys=None, verbose=False):
        """Return the generated ClickHouse SQL, split into lines, for EXPLAIN."""
        qs = self._build_query(quals, columns, sortkeys)
        return qs.as_sql().split('\n')

    def _build_query(self, quals, columns, sortkeys=None):
        """Translate Multicorn quals and sort keys into an ORM queryset;
        quals with unsupported operators are warned about and skipped."""
        columns = columns or [self._get_smallest_column()] # use a small column when PostgreSQL doesn't need any columns
        qs = self.model.objects_in(self.db).only(*columns)
        if sortkeys:
            order = ['-' + sk.attname if sk.is_reversed else sk.attname for sk in sortkeys]
            qs = qs.order_by(*order)
        for qual in quals:
            operator = OPERATORS.get(qual.operator)
            if operator:
                qs = qs.filter(**{qual.field_name + '__' + operator: qual.value})
            else:
                self._warn('Qual not pushed to ClickHouse: %s' % qual)
        return qs

    def _get_column_stats(self, columns):
        """Collect per-column stats: average rows per value and byte size."""
        column_stats = {}
        # Get total number of rows
        total_rows = self.model.objects_in(self.db).count()
        # Get average rows per value in column (total divided by number of unique values)
        exprs = ['intDiv(%d, uniqCombined(%s)) as %s' % (total_rows, c, c) for c in columns]
        sql = "SELECT %s FROM $db.`%s`" % (', '.join(exprs), self.table_name)
        for row in self.db.select(sql):
            for c in columns:
                column_stats[c] = dict(average_rows=getattr(row, c), size=4)
        # Get average size per column. This may fail because data_uncompressed_bytes is a recent addition
        sql = "SELECT * FROM system.columns WHERE database='%s' AND table='%s'" % (self.db_name, self.table_name)
        for col_def in self.db.select(sql):
            column_stats[col_def.name]['size'] = self._calc_col_size(col_def, total_rows)
        # Debug
        for c in columns:
            log_to_postgres(c + ': ' + repr(column_stats[c]))
        return column_stats

    def _calc_col_size(self, col_def, total_rows):
        """Best-effort average byte size of one value in this column
        (falls back to 8 when nothing better is known)."""
        size = 0
        if col_def.type in COLUMN_SIZES:
            # A column with a fixed size
            size = COLUMN_SIZES[col_def.type]
        elif hasattr(col_def, 'data_uncompressed_bytes'):
            # Non fixed size, calculate average size
            size = int(float(col_def.data_uncompressed_bytes) / total_rows)
        elif hasattr(col_def, 'bytes'):
            # Assume x10 compression and calculate average size
            size = int(float(col_def.bytes) * 10 / total_rows)
        return size or 8

    def _get_smallest_column(self):
        # Cheapest column to fetch when PostgreSQL asks for no columns at all.
        item = min(self.column_stats.items(), key=lambda item: item[1]['size'])
        return item[0]

    @classmethod
    def import_schema(cls, schema, srv_options, options, restriction_type, restricts):
        """IMPORT FOREIGN SCHEMA support: build table definitions for the
        (optionally restricted) tables of the ClickHouse database."""
        db_name = options.get('db_name', 'default')
        db_url  = options.get('db_url', 'http://localhost:8123/')
        db      = Database(db_name, db_url)
        tables  = cls._tables_to_import(db, restriction_type, restricts)
        return [cls._import_table(db, table, options) for table in tables]

    @classmethod
    def _tables_to_import(cls, db, restriction_type, restricts):
        """List table names, honoring LIMIT TO / EXCEPT restrictions."""
        sql = "SELECT name FROM system.tables WHERE database='%s'" % db.db_name
        if restriction_type:
            op = 'IN' if restriction_type == 'limit' else 'NOT IN'
            names = ', '.join("'%s'" % name for name in restricts)
            sql += ' AND name %s (%s)' % (op, names)
        return [row.name for row in db.select(sql)]

    @classmethod
    def _import_table(cls, db, table, options):
        """Build a TableDefinition for *table*, skipping unsupported types."""
        columns = []
        sql = "SELECT name, type FROM system.columns where database='%s' and table='%s'" % (db.db_name, table)
        for row in db.select(sql):
            try:
                columns.append(ColumnDefinition(row.name, type_name=_convert_column_type(row.type)))
            except KeyError:
                cls._warn('Unsupported column type %s in table %s was skipped' % (row.type, table))
        merged_options = dict(options, table_name=table)
        return TableDefinition(table, columns=columns, options=merged_options)

    @classmethod
    def _warn(cls, msg):
        # Surface a warning in the PostgreSQL server log.
        log_to_postgres(msg, WARNING)
class DatabaseTestCase(unittest.TestCase):
    """Integration tests for Database insert/count/select/paginate against
    a live ClickHouse server, using the ``test-db`` database (note the
    hyphen: queries must backquote the database name).

    Note: ``assertEquals`` was replaced with ``assertEqual`` throughout;
    the alias is deprecated and removed in Python 3.12.
    """

    def setUp(self):
        self.database = Database('test-db')
        self.database.create_table(Person)

    def tearDown(self):
        self.database.drop_table(Person)
        self.database.drop_database()

    def _insert_and_check(self, data, count):
        """Insert *data* and assert the Person table now holds *count* rows."""
        self.database.insert(data)
        self.assertEqual(count, self.database.count(Person))

    def test_insert__generator(self):
        self._insert_and_check(self._sample_data(), len(data))

    def test_insert__list(self):
        self._insert_and_check(list(self._sample_data()), len(data))

    def test_insert__iterator(self):
        self._insert_and_check(iter(self._sample_data()), len(data))

    def test_insert__empty(self):
        self._insert_and_check([], 0)

    def test_count(self):
        self.database.insert(self._sample_data())
        self.assertEqual(self.database.count(Person), 100)
        self.assertEqual(self.database.count(Person, "first_name = 'Courtney'"), 2)
        self.assertEqual(self.database.count(Person, "birthday > '2000-01-01'"), 22)
        self.assertEqual(self.database.count(Person, "birthday < '1970-03-01'"), 0)

    def test_select(self):
        self._insert_and_check(self._sample_data(), len(data))
        query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
        results = list(self.database.select(query, Person))
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].last_name, 'Durham')
        self.assertEqual(results[0].height, 1.72)
        self.assertEqual(results[1].last_name, 'Scott')
        self.assertEqual(results[1].height, 1.70)

    def test_select_partial_fields(self):
        self._insert_and_check(self._sample_data(), len(data))
        query = "SELECT first_name, last_name FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
        results = list(self.database.select(query, Person))
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].last_name, 'Durham')
        self.assertEqual(results[0].height, 0) # default value
        self.assertEqual(results[1].last_name, 'Scott')
        self.assertEqual(results[1].height, 0) # default value

    def test_select_ad_hoc_model(self):
        self._insert_and_check(self._sample_data(), len(data))
        query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name"
        results = list(self.database.select(query))
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].__class__.__name__, 'AdHocModel')
        self.assertEqual(results[0].last_name, 'Durham')
        self.assertEqual(results[0].height, 1.72)
        self.assertEqual(results[1].last_name, 'Scott')
        self.assertEqual(results[1].height, 1.70)

    def test_pagination(self):
        """Every page size must return each object exactly once."""
        self._insert_and_check(self._sample_data(), len(data))
        # Try different page sizes
        for page_size in (1, 2, 7, 10, 30, 100, 150):
            # Iterate over pages and collect all instances
            page_num = 1
            instances = set()
            while True:
                page = self.database.paginate(Person, 'first_name, last_name', page_num, page_size)
                self.assertEqual(page.number_of_objects, len(data))
                self.assertGreater(page.pages_total, 0)
                # Idiom fix: a plain loop instead of a throwaway list
                # comprehension used only for its side effect.
                for obj in page.objects:
                    instances.add(obj.to_tsv())
                if page.pages_total == page_num:
                    break
                page_num += 1
            # Verify that all instances were returned
            self.assertEqual(len(instances), len(data))

    def _sample_data(self):
        """Yield Person instances built from the module-level ``data``."""
        for entry in data:
            yield Person(**entry)
class DateTimeFieldWithTzTest(unittest.TestCase):
    """Checks that DateTime fields declared with explicit timezones come
    back from the database as timezone-aware datetimes in the right zone."""

    def setUp(self):
        self.database = Database('test-db', log_statements=True)
        if self.database.server_version < (20, 1, 2, 4):
            raise unittest.SkipTest('ClickHouse version too old')
        self.database.create_table(ModelWithTz)

    def tearDown(self):
        self.database.drop_database()

    def test_ad_hoc_model(self):
        # The same instant expressed twice: once naive, once with an
        # explicit +03:00 offset (07:00+0300 == 04:00 UTC).
        field_names = ('datetime_no_tz_field', 'datetime_tz_field',
                       'datetime64_tz_field', 'datetime_utc_field')
        self.database.insert([
            ModelWithTz(**{name: '2020-06-11 04:00:00' for name in field_names}),
            ModelWithTz(**{name: '2020-06-11 07:00:00+0300' for name in field_names}),
        ])
        query = 'SELECT * from $db.modelwithtz ORDER BY datetime_no_tz_field'
        results = list(self.database.select(query))

        # Every stored value must compare equal to 04:00 UTC.
        instant = datetime.datetime(2020, 6, 11, 4, 0, 0, tzinfo=pytz.UTC)
        for row in results:
            for name in field_names:
                self.assertEqual(getattr(row, name), instant)

        # Each field must carry the timezone declared on the model:
        # server zone for the plain field, Europe/Madrid for the tz fields,
        # UTC for the utc field.
        madrid_zone = pytz.timezone('Europe/Madrid').zone
        utc_zone = pytz.timezone('UTC').zone
        for row in results:
            self.assertEqual(row.datetime_no_tz_field.tzinfo.zone,
                             self.database.server_timezone.zone)
            self.assertEqual(row.datetime_tz_field.tzinfo.zone, madrid_zone)
            self.assertEqual(row.datetime64_tz_field.tzinfo.zone, madrid_zone)
            self.assertEqual(row.datetime_utc_field.tzinfo.zone, utc_zone)
# Exemple #28
# 0
class NullableFieldsTest(unittest.TestCase):
    """Tests for NullableField: value conversion, the '\\N' NULL marker,
    isinstance behavior, and round-tripping NULLs through the database."""

    def setUp(self):
        self.database = Database('test-db')
        self.database.create_table(ModelWithNullable)

    def tearDown(self):
        self.database.drop_database()

    def test_nullable_datetime_field(self):
        """to_python/to_db_string round-trip for Nullable(DateTime)."""
        f = NullableField(DateTimeField())
        epoch = datetime(1970, 1, 1, tzinfo=pytz.utc)
        # Valid values
        for value in (date(1970, 1, 1), datetime(1970, 1, 1), epoch,
                      epoch.astimezone(pytz.timezone('US/Eastern')),
                      epoch.astimezone(pytz.timezone('Asia/Jerusalem')),
                      '1970-01-01 00:00:00', '1970-01-17 00:00:17',
                      '0000-00-00 00:00:00', 0, '\\N'):
            dt = f.to_python(value, pytz.utc)
            if value == '\\N':
                # '\N' is the TSV representation of NULL.
                self.assertIsNone(dt)
            else:
                self.assertEqual(dt.tzinfo, pytz.utc)
            # Verify that conversion to and from db string does not change value
            dt2 = f.to_python(f.to_db_string(dt, quote=False), pytz.utc)
            self.assertEqual(dt, dt2)
        # Invalid values
        for value in ('nope', '21/7/1999', 0.5):
            with self.assertRaises(ValueError):
                f.to_python(value, pytz.utc)

    def test_nullable_uint8_field(self):
        """to_python for Nullable(UInt8): ints, numeric strings, floats and NULL."""
        f = NullableField(UInt8Field())
        # Valid values
        for value in (17, '17', 17.0, '\\N'):
            python_value = f.to_python(value, pytz.utc)
            if value == '\\N':
                self.assertIsNone(python_value)
                # None must serialize back to the '\N' marker.
                self.assertEqual(value, f.to_db_string(python_value))
            else:
                self.assertEqual(python_value, 17)

        # Invalid values
        for value in ('nope', date.today()):
            with self.assertRaises(ValueError):
                f.to_python(value, pytz.utc)

    def test_nullable_string_field(self):
        """Only the exact '\\N' marker becomes None; escaped '\\\\N' and
        plain 'N' stay ordinary strings."""
        f = NullableField(StringField())
        # Valid values
        for value in ('\\\\N', 'N', 'some text', '\\N'):
            python_value = f.to_python(value, pytz.utc)
            if value == '\\N':
                self.assertIsNone(python_value)
                self.assertEqual(value, f.to_db_string(python_value))
            else:
                self.assertEqual(python_value, value)

    def test_isinstance(self):
        """NullableField.isinstance should see through the wrapper (even
        when nested) to the inner field's class hierarchy."""
        for field in (StringField, UInt8Field, Float32Field, DateTimeField):
            f = NullableField(field())
            self.assertTrue(f.isinstance(field))
            self.assertTrue(f.isinstance(NullableField))
        for field in (Int8Field, Int16Field, Int32Field, Int64Field,
                      UInt8Field, UInt16Field, UInt32Field, UInt64Field):
            f = NullableField(field())
            self.assertTrue(f.isinstance(BaseIntField))
        for field in (Float32Field, Float64Field):
            f = NullableField(field())
            self.assertTrue(f.isinstance(BaseFloatField))
        # Nested nullables still report the innermost field's type.
        f = NullableField(NullableField(UInt32Field()))
        self.assertTrue(f.isinstance(BaseIntField))
        self.assertTrue(f.isinstance(NullableField))
        self.assertFalse(f.isinstance(BaseFloatField))

    def _insert_sample_data(self):
        """Insert four rows covering every combination of NULL/non-NULL
        in the nullable columns."""
        dt = date(1970, 1, 1)
        self.database.insert([
            ModelWithNullable(date_field='2016-08-30',
                              null_str='',
                              null_int=42,
                              null_date=dt),
            ModelWithNullable(date_field='2016-08-30',
                              null_str='nothing',
                              null_int=None,
                              null_date=None),
            ModelWithNullable(date_field='2016-08-31',
                              null_str=None,
                              null_int=42,
                              null_date=dt),
            ModelWithNullable(date_field='2016-08-31',
                              null_str=None,
                              null_int=None,
                              null_date=None)
        ])

    def _assert_sample_data(self, results):
        """Verify the four rows inserted by _insert_sample_data, in
        date_field order."""
        dt = date(1970, 1, 1)
        self.assertEqual(len(results), 4)
        # NOTE(review): row 0 was inserted with null_str='' yet is expected
        # to read back as None — presumably '' maps to NULL here; confirm.
        self.assertIsNone(results[0].null_str)
        self.assertEqual(results[0].null_int, 42)
        self.assertEqual(results[0].null_date, dt)
        self.assertIsNone(results[1].null_date)
        self.assertEqual(results[1].null_str, 'nothing')
        self.assertIsNone(results[1].null_date)
        self.assertIsNone(results[2].null_str)
        self.assertEqual(results[2].null_date, dt)
        self.assertEqual(results[2].null_int, 42)
        self.assertIsNone(results[3].null_int)
        self.assertIsNone(results[3].null_str)
        self.assertIsNone(results[3].null_date)

    def test_insert_and_select(self):
        """Round-trip through a model-class select."""
        self._insert_sample_data()
        query = 'SELECT * from $table ORDER BY date_field'
        results = list(self.database.select(query, ModelWithNullable))
        self._assert_sample_data(results)

    def test_ad_hoc_model(self):
        """Round-trip through an ad hoc (schema-inferred) select."""
        self._insert_sample_data()
        query = 'SELECT * from $db.modelwithnullable ORDER BY date_field'
        results = list(self.database.select(query))
        self._assert_sample_data(results)
class MigrationsTestCase(unittest.TestCase):
    """End-to-end tests of the migrations mechanism: applies the
    tests.sample_migrations package one step at a time and verifies the
    resulting schema (and, for data migrations, the table contents)."""

    def setUp(self):
        self.database = Database('test-db', log_statements=True)
        # Start from a clean slate so every migration is re-applied.
        self.database.drop_table(MigrationHistory)

    def tearDown(self):
        self.database.drop_database()

    def table_exists(self, model_class):
        """Return True if the model's table exists in the test database."""
        query = "EXISTS TABLE $db.`%s`" % model_class.table_name()
        return next(self.database.select(query)).result == 1

    def get_table_fields(self, model_class):
        """Return the table's columns as an ordered [(name, type), ...] list."""
        query = "DESC `%s`.`%s`" % (self.database.db_name,
                                    model_class.table_name())
        return [(row.name, row.type) for row in self.database.select(query)]

    def get_table_def(self, model_class):
        """Return the raw CREATE TABLE statement for the model's table."""
        return self.database.raw('SHOW CREATE TABLE $db.`%s`' %
                                 model_class.table_name())

    def test_migrations(self):
        """Apply each sample migration in order and check its effect."""
        # Creation and deletion of table
        self.database.migrate('tests.sample_migrations', 1)
        self.assertTrue(self.table_exists(Model1))
        self.database.migrate('tests.sample_migrations', 2)
        self.assertFalse(self.table_exists(Model1))
        self.database.migrate('tests.sample_migrations', 3)
        self.assertTrue(self.table_exists(Model1))
        # Adding, removing and altering simple fields
        self.assertEqual(self.get_table_fields(Model1), [('date', 'Date'),
                                                         ('f1', 'Int32'),
                                                         ('f2', 'String')])
        self.database.migrate('tests.sample_migrations', 4)
        self.assertEqual(self.get_table_fields(Model2),
                         [('date', 'Date'), ('f1', 'Int32'), ('f3', 'Float32'),
                          ('f2', 'String'), ('f4', 'String'),
                          ('f5', 'Array(UInt64)')])
        self.database.migrate('tests.sample_migrations', 5)
        self.assertEqual(self.get_table_fields(Model3), [('date', 'Date'),
                                                         ('f1', 'Int64'),
                                                         ('f3', 'Float64'),
                                                         ('f4', 'String')])
        # Altering enum fields
        self.database.migrate('tests.sample_migrations', 6)
        self.assertTrue(self.table_exists(EnumModel1))
        self.assertEqual(self.get_table_fields(EnumModel1),
                         [('date', 'Date'),
                          ('f1', "Enum8('dog' = 1, 'cat' = 2, 'cow' = 3)")])
        self.database.migrate('tests.sample_migrations', 7)
        self.assertTrue(self.table_exists(EnumModel1))
        self.assertEqual(
            self.get_table_fields(EnumModel2),
            [('date', 'Date'),
             ('f1', "Enum16('dog' = 1, 'cat' = 2, 'horse' = 3, 'pig' = 4)")])
        # Materialized fields and alias fields
        self.database.migrate('tests.sample_migrations', 8)
        self.assertTrue(self.table_exists(MaterializedModel))
        self.assertEqual(self.get_table_fields(MaterializedModel),
                         [('date_time', "DateTime"), ('date', 'Date')])
        self.database.migrate('tests.sample_migrations', 9)
        self.assertTrue(self.table_exists(AliasModel))
        self.assertEqual(self.get_table_fields(AliasModel),
                         [('date', 'Date'), ('date_alias', "Date")])
        # Buffer models creation and alteration
        self.database.migrate('tests.sample_migrations', 10)
        self.assertTrue(self.table_exists(Model4))
        self.assertTrue(self.table_exists(Model4Buffer))
        self.assertEqual(self.get_table_fields(Model4), [('date', 'Date'),
                                                         ('f1', 'Int32'),
                                                         ('f2', 'String')])
        self.assertEqual(self.get_table_fields(Model4Buffer),
                         [('date', 'Date'), ('f1', 'Int32'), ('f2', 'String')])
        self.database.migrate('tests.sample_migrations', 11)
        self.assertEqual(self.get_table_fields(Model4), [('date', 'Date'),
                                                         ('f3', 'DateTime'),
                                                         ('f2', 'String')])
        self.assertEqual(self.get_table_fields(Model4Buffer),
                         [('date', 'Date'), ('f3', 'DateTime'),
                          ('f2', 'String')])

        # Data migrations: 12 inserts three rows, 13 adds a fourth
        self.database.migrate('tests.sample_migrations', 12)
        self.assertEqual(self.database.count(Model3), 3)
        data = [
            item.f1 for item in self.database.select(
                'SELECT f1 FROM $table ORDER BY f1', model_class=Model3)
        ]
        self.assertListEqual(data, [1, 2, 3])

        self.database.migrate('tests.sample_migrations', 13)
        self.assertEqual(self.database.count(Model3), 4)
        data = [
            item.f1 for item in self.database.select(
                'SELECT f1 FROM $table ORDER BY f1', model_class=Model3)
        ]
        self.assertListEqual(data, [1, 2, 3, 4])

        # Materialized/alias models with an extra computed column
        self.database.migrate('tests.sample_migrations', 14)
        self.assertTrue(self.table_exists(MaterializedModel1))
        self.assertEqual(self.get_table_fields(MaterializedModel1),
                         [('date_time', 'DateTime'), ('int_field', 'Int8'),
                          ('date', 'Date'), ('int_field_plus_one', 'Int8')])
        self.assertTrue(self.table_exists(AliasModel1))
        self.assertEqual(self.get_table_fields(AliasModel1),
                         [('date', 'Date'), ('int_field', 'Int8'),
                          ('date_alias', 'Date'),
                          ('int_field_plus_one', 'Int8')])
        # Codecs and low cardinality
        self.database.migrate('tests.sample_migrations', 15)
        self.assertTrue(self.table_exists(Model4_compressed))
        if self.database.has_low_cardinality_support:
            self.assertEqual(self.get_table_fields(Model2LowCardinality),
                             [('date', 'Date'),
                              ('f1', 'LowCardinality(Int32)'),
                              ('f3', 'LowCardinality(Float32)'),
                              ('f2', 'LowCardinality(String)'),
                              ('f4', 'LowCardinality(Nullable(String))'),
                              ('f5', 'Array(LowCardinality(UInt64))')])
        else:
            # Older servers silently keep the plain column types.
            logging.warning('No support for low cardinality')
            self.assertEqual(self.get_table_fields(Model2),
                             [('date', 'Date'), ('f1', 'Int32'),
                              ('f3', 'Float32'), ('f2', 'String'),
                              ('f4', 'Nullable(String)'),
                              ('f5', 'Array(UInt64)')])

        if self.database.server_version >= (19, 14, 3, 3):
            # Creating constraints
            self.database.migrate('tests.sample_migrations', 16)
            self.assertTrue(self.table_exists(ModelWithConstraints))
            self.database.insert([ModelWithConstraints(f1=101, f2='a')])
            with self.assertRaises(ServerError):
                self.database.insert([ModelWithConstraints(f1=99, f2='a')])
            with self.assertRaises(ServerError):
                self.database.insert([ModelWithConstraints(f1=101, f2='x')])
            # Modifying constraints
            self.database.migrate('tests.sample_migrations', 17)
            self.database.insert([ModelWithConstraints(f1=99, f2='a')])
            with self.assertRaises(ServerError):
                self.database.insert([ModelWithConstraints(f1=101, f2='a')])
            with self.assertRaises(ServerError):
                self.database.insert([ModelWithConstraints(f1=99, f2='x')])

        if self.database.server_version >= (20, 1, 2, 4):
            # Creating indexes
            self.database.migrate('tests.sample_migrations', 18)
            self.assertTrue(self.table_exists(ModelWithIndex))
            self.assertIn('INDEX index ', self.get_table_def(ModelWithIndex))
            self.assertIn('INDEX another_index ',
                          self.get_table_def(ModelWithIndex))
            # Modifying indexes
            self.database.migrate('tests.sample_migrations', 19)
            self.assertNotIn('INDEX index ',
                             self.get_table_def(ModelWithIndex))
            self.assertIn('INDEX index2 ', self.get_table_def(ModelWithIndex))
            self.assertIn('INDEX another_index ',
                          self.get_table_def(ModelWithIndex))
class Designer(object):
    def __init__(self):
        self.p = argparse.ArgumentParser(
            description="Designer is a program which finds the best sequence that could fold to the given protein model")

    def setup(self):
        self.db = None
        self.args = None
        self.predicted_sequence = ""
        self.p.add_argument("-w", "--workdir", default=".",
                            help="The working directory for this script to save and load temporary. By default it "
                                 "takes the current working directory. "
                                 "memory files required to function properly.")
        self.p.add_argument("-q", "--pdbs", type=str, default=None,
                            help="Local repository containing PDB Files to read from")
        self.p.add_argument("-d", "--database", default="propensities", help="Database Name to use")
        self.p.add_argument("-t", "--host", default="localhost", help="Database Host to use. defaults to localhost")
        self.p.add_argument("-r", "--port", default=8123, help="Database Default port to use")
        self.p.add_argument("-u", "--username", help="Username of the database to use")
        self.p.add_argument("-o", "--password", help="Database password to use to login")
        self.p.add_argument("-i","--window",type=int,default=8,help="The window length to select from the model in each iteration.")
        self.p.add_argument("-s", "--model", type=str, default=None
                            ,
                            help="Source PDB which contains the 3D structure of a protein or part of a protein (Fragment), that you want to search for similars in the database.")
        self.p.add_argument("-m", "--rmsd", type=float, default=-1,
                            help="Return matched proteins with this RMSD value only and exclude any others.")
        self.p.add_argument("-l", "--limit", type=int, default=10,
                            help="Total Number of results to include. Defaults to 10.")
        self.p.add_argument('-c', '--cutoff', default=23.9, type=float, help="Cutoff value to use to mark a residue "
                                                                             "Solvent-accessible surface area as polar or "
                                                                             "hydrophobic, if the SASA of a residue equal "
                                                                             "to or higher than this value it will be "
                                                                             "considered polar otherwise it will be "
                                                                             "marked as hydrophobic")
        self.p.add_argument("-f", "--fuzzy", nargs="?", const=True, default=False, help="Perform Fuzzy Search. Allow "
                                                                                        "matching similar but not "
                                                                                        "identical results.")
        self.p.add_argument("-z", "--fuzzylevel", type=float, default=0.90, help="Include only results with equal or ")
        if len(sys.argv) <= 1:
            self.p.print_help()
            return False
        self.args = self.p.parse_args()
        return True
    def connect(self):
        """Open the ClickHouse connection described by the parsed CLI options."""
        host = self.args.host or "localhost"
        port = self.args.port or 8123
        self.db = Database(db_name=self.args.database,
                           db_url="http://{0}:{1}".format(host, port),
                           username=self.args.username,
                           password=self.args.password)
        return True

    def start(self):
        """Entry point: parse arguments, connect to the database, and run
        sequence prediction on the supplied model (if any)."""
        if not self.setup():
            return
        self.connect()
        # Idiom fix: `x is not None` instead of `not x is None`.
        if self.args.model is not None:
            self.load_model()
        else:
            print("Please specify Protein Model to predict sequence for.")
            self.p.print_help()

    def get_secondary_structure_details(self, name, pdb_file, aa_only=False):
        """Parse a PDB file and return (name, sequence, secondary-structure
        string, per-residue SASA list) via DSSP."""
        structure = PDBParser().get_structure(name, pdb_file)
        dssp = DSSP(structure[0], pdb_file, acc_array="Wilke")
        # Column 2 of each DSSP entry is the SS letter; column 3 is the
        # relative accessibility, scaled by the residue's max SASA.
        ss = "".join(entry[2] for entry in dssp)
        sasa = [residues[entry[1]] * entry[3] for entry in dssp]
        seq = ""
        for chain in PPBuilder().build_peptides(structure, aa_only=aa_only):
            seq += chain.get_sequence()
        return name, seq, ss, sasa

    def load_model(self):
        """Load the model PDB given on the command line and predict its sequence."""
        model_file = self.args.model
        name = os.path.splitext(os.path.basename(model_file))[0]
        protein_name, seq, ss, sasa = self.get_secondary_structure_details(
            name, model_file)
        print("Predicting Sequence. Please Wait....")
        self.predict_sequence(protein_name, seq, ss, sasa)

    # def predict_sequence_offset(self,protein_name,seq,ss,sasa):
    #     pointer = 0
    #     window = self.args.window if self.args.window > 0 else 8
    #     seq_length = len(seq)
    #     iterations = []
    #     while seq_length - pointer > 0:
    #         if seq_length - pointer < window:
    #             window = seq_length - pointer
    #         current_ss = ss[pointer:pointer + window]
    #         current_sasa = sasa[pointer:pointer + window]
    #         current_seq = seq[pointer:pointer + window]
    #         pointer += 1
    #         current_asa = [1 if a >= self.args.cutoff else 0 for a in current_sasa]
    #         common_sequence = self.get_enhanced(current_ss, current_asa)
    #         found_proteins = self.search_by_common_pattern(common_sequence)
    #         if len(found_proteins) <= 0:
    #             continue
    #         if self.args.fuzzy:
    #             continue
    #
    #         for row in found_proteins:
    #             self.calculate_RMSD(row, pointer, window, aa_only=False)
    #         sorted_proteins = sorted([x for x in found_proteins if x.rmsd > -1], key=lambda protein: protein.rmsd,
    #                                  reverse=False)
    #         if len(sorted_proteins) < 1:
    #             continue
    #         best_protein = sorted_proteins[0]
    #         if best_protein.protein_id == protein_name and len(found_proteins) > 1:
    #             best_protein = sorted_proteins[1]
    #         pos = best_protein.pos
    #         current_sequence = best_protein.sequence[pos:pos + len(current_seq) + 1]
    #         if len(current_sequence) < self.args.window:
    #             current_sequence += "".join(["-"]*(self.args.window-len(current_sequence)))
    #         iterations.append(current_sequence)
    #     for row in range(0,len(iterations)):
    #         letters = []
    #         for col in range(0,self.args.window):
    #             letters.append(iterations[row][col])
    #         freqs = Counter(letters)
    #         sorted_letters= sorted(freqs.items(),key=lambda item:item[1],reverse=True)
    #         if len(sorted_letters) > 0:
    #             self.predicted_sequence += sorted_letters[0][0]
    #     print("Sequence Prediction Finished.")
    #     print("Predicted Sequence:")
    #     print(self.predicted_sequence)
    #
    # def get_unique_ss_island(self,ss):
    #     current = ss[0]
    #     island = current
    #     island_positions = [0]
    #     for index in range(1,len(ss)):
    #         if ss[index] == current:
    #             island += ss[index]
    #             island_positions.append(index)
    #         else:
    #             yield island , island_positions
    #             current = ss[index]
    #             island = current
    #             island_positions = [index]
    #     if len(island) > 0 and len(island_positions) > 0:
    #         yield island,island_positions
    #
    # def predict_sequence2(self,protein_name,seq,ss,sasa):
    #     for island, positions in self.get_unique_ss_island(ss):
    #         current_ss = island
    #         current_sasa = [sasa[x] for x in positions]
    #         current_seq = "".join([seq[x] for x in positions])
    #         current_asa = [1 if a >= self.args.cutoff else 0 for a in current_sasa]
    #         common_sequence = self.get_enhanced(current_ss, current_asa)
    #         found_proteins = self.search_by_common_pattern(common_sequence)
    #         if len(found_proteins) <= 0:
    #             continue
    #         if self.args.fuzzy:
    #             continue
    #         pointer = positions[0]
    #         window = positions[-1] - pointer
    #         for row in found_proteins:
    #             self.calculate_RMSD(row, pointer, window, aa_only=False)
    #         sorted_proteins = sorted([x for x in found_proteins if x.rmsd > -1], key=lambda protein: protein.rmsd,
    #                                  reverse=False)
    #         if len(sorted_proteins) < 1:
    #             continue
    #         best_protein = sorted_proteins[0]
    #         if best_protein.protein_id == protein_name and len(found_proteins) > 1:
    #             best_protein = sorted_proteins[1]
    #         pos = best_protein.pos
    #         self.predicted_sequence += best_protein.sequence[pos:pos + len(current_seq) + 1]
    #     print("Sequence Prediction Finished.")
    #     print("Predicted Sequence:")
    #     print(self.predicted_sequence)


    def predict_sequence(self,protein_name,seq,ss,sasa):
        """Predict an amino-acid sequence for the model protein.

        Slides a non-overlapping window of length --window over the
        secondary-structure / SASA arrays, searches the database for
        proteins matching the encoded pattern, and appends the best
        (lowest-RMSD) match's sub-sequence to self.predicted_sequence.

        :param protein_name: name of the query protein (used to avoid
            trivially matching itself)
        :param seq: amino-acid sequence of the model
        :param ss: DSSP secondary-structure string, parallel to seq
        :param sasa: per-residue solvent-accessible surface areas, parallel to seq
        """
        pointer = 0
        # Fall back to a window of 8 if the CLI value is not positive.
        window = self.args.window if self.args.window > 0 else 8
        seq_length = len(seq)
        while seq_length - pointer > 0:
            # Shrink the final window to whatever remains of the sequence.
            if seq_length - pointer < window:
                window= seq_length - pointer
            current_ss = ss[pointer:pointer+window]
            current_sasa = sasa[pointer:pointer+window]
            current_seq = seq[pointer:pointer+window]
            pointer += window
            # 1 = polar (SASA >= cutoff), 0 = hydrophobic.
            current_asa = [1 if a >= self.args.cutoff else 0 for a in current_sasa]
            common_sequence = self.get_enhanced(current_ss,current_asa)
            found_proteins = self.search_by_common_pattern(common_sequence)
            if len(found_proteins) <= 0:
                continue
            # NOTE(review): fuzzy mode skips RMSD scoring and contributes
            # nothing to the prediction — looks unfinished; confirm intent.
            if self.args.fuzzy:
                continue

            for row in found_proteins:
                self.calculate_RMSD(row,pointer,window,aa_only=False)
            # Keep only rows that were successfully scored (rmsd > -1),
            # best (smallest) RMSD first.
            sorted_proteins = sorted([x for x in found_proteins if x.rmsd > -1],key=lambda protein: protein.rmsd,reverse=False)
            if len(sorted_proteins) < 1:
                continue
            best_protein = sorted_proteins[0]
            # Avoid trivially matching the query protein itself.
            if best_protein.protein_id == protein_name and len(found_proteins) > 1:
                best_protein = sorted_proteins[1]
            pos = best_protein.pos
            self.predicted_sequence += best_protein.sequence[pos:pos+len(current_seq)+1]
        print("Sequence Prediction Finished.")
        print("Predicted Sequence:")
        print(self.predicted_sequence)






    def __get_structure__(self, file_path):
        """Return a Bio structure parsed from *file_path* (.pdb or .cif)."""
        base_name = os.path.basename(file_path)
        name, ext = os.path.splitext(base_name)
        # mmCIF files need a different parser than classic PDB files.
        parser = MMCIFParser() if 'cif' in ext else PDBParser()
        return parser.get_structure(name, file_path)

    def get_target_file(self, protein_id):
        if not self.args.pdbs is None:
            which_file = os.path.join(self.args.pdbs, protein_id)
            if os.path.exists(which_file + ".pdb"):
                return which_file + ".pdb"
            elif os.path.exists(which_file + ".cif"):
                return which_file + ".cif"
            else:
                return which_file + ".pdb"
        elif os.path.exists(os.path.join(self.args.workdir, "{0}.pdb".format(protein_id))):
            return os.path.join(self.args.workdir, "{0}.pdb".format(protein_id))
        else:
            print("Downloading File : {0}".format(protein_id))
            download_url = "https://files.rcsb.org/download/{0}.pdb".format(protein_id)
            response = urllib2.urlopen(download_url)
            output_file = os.path.join(self.args.workdir, "{0}.pdb".format(protein_id))
            with open(output_file, mode='w') as output:
                output.write(response.read())
            print("Downloaded.")
            return output_file


    def get_phi(self, previous, source_res):
       try:
           C_1 = previous['C'].get_vector()
           N = source_res['N'].get_vector()
           CA = source_res['CA'].get_vector()
           C = source_res['C'].get_vector()
           return degrees(calc_dihedral(C_1, N, CA, C))
       except Exception as e:
           return 0.0

    def get_psi(self, target_res,next_res):
       try:
           N = target_res['N'].get_vector()
           CA = target_res['CA'].get_vector()
           C = target_res['C'].get_vector()
           N1_1 = next_res['N'].get_vector()
           return degrees(calc_dihedral(N, CA, C, N1_1))
       except Exception as e:
           return 0.0

    def __get_elenated_topology(self, source_residues, target_residues):
        """
        This method will calculate the elenated topology mean square deviation
        :param source_residues: Fragment source residues
        :param target_residues: Target protein residues
        :return: Elenated T(msd) Value between these two proteins
        """
        # target residues should be longer than the source residues in most of cases
        if len(source_residues) > len(target_residues):
            return 0.0
        deviation = 0.0
        total = 0.0
        for index, res in enumerate(source_residues[:-1]):
            if index == 0:
                continue
            source_res = source_residues[index]
            target_res = target_residues[index]
            # calculate Phi and Psi torsional angles for the current residues
            source_phi = self.get_phi(source_residues[index - 1], source_res)
            target_phi = self.get_phi(target_residues[index - 1], target_res)
            source_psi = self.get_psi(source_res,source_residues[index+1])
            target_psi = self.get_psi(target_res,target_residues[index+1])
            deviation += sqrt((((source_phi - target_phi) / abs(source_phi + target_phi)) ** 2) + ((source_psi - target_psi) / abs(source_psi + target_psi)) ** 2)
            total += 1
        if total == 0:
            return 0.0
        else:

            return deviation / (float(total) * 100.0)

    def calculate_RMSD(self, row, source_position, fragment_length, aa_only=False):
        """Superimpose the CA atoms of a fragment of the query model onto the
        matching fragment of the hit protein and store the RMSD on the row.

        :param row: result row; must expose ``pos`` and ``protein_id``; gets a
                    ``rmsd`` attribute set (rounded to 4 decimals, or -1 on
                    any failure).
        :param source_position: fragment start index in the query structure
        :param fragment_length: number of residues to superimpose
        :param aa_only: passed through to PPBuilder.build_peptides
        """
        try:
            if self.args.model is None:
                # No reference model supplied: RMSD cannot be computed.
                # (The original fell through here and crashed in
                # __get_structure__(None); return early instead.)
                setattr(row, "rmsd", -1)
                return
            target_position = row.pos
            source_structure = self.__get_structure__(self.args.model)
            builder = PPBuilder()
            fixed_residues = []
            for pp in builder.build_peptides(source_structure, aa_only=aa_only):
                fixed_residues.extend(pp)
            fixed = [res['CA'] for res in fixed_residues][source_position:source_position + fragment_length]
            target_file = self.get_target_file(row.protein_id)
            if target_file is None:
                setattr(row, "rmsd", -1)
                return
            target_structure = self.__get_structure__(target_file)
            moving_residues = []
            for pp in builder.build_peptides(target_structure, aa_only=aa_only):
                moving_residues.extend(pp)
            moving = [res['CA'] for res in moving_residues][target_position:target_position + fragment_length]
            # Superimposer requires equally sized atom lists.
            if len(fixed) != len(moving):
                setattr(row, "rmsd", -1)
                return
            sup = Bio.PDB.Superimposer()
            sup.set_atoms(fixed, moving)
            sup.apply(target_structure[0].get_atoms())
            setattr(row, "rmsd", round(sup.rms, 4))
        except Exception as e:
            # str(e) is portable; Exception has no `.message` in Python 3.
            print(str(e))
            setattr(row, "rmsd", -1)

    def get_enhanced(self, ss, pattern):
        """Encode a secondary-structure string into an "enhanced" alphabet.

        Each DSSP-style letter in ``ss`` is mapped to one of two letter sets
        depending on the digit at the same index of ``pattern`` (presumably a
        per-residue solvent-accessibility flag — 0 selects the I..P range,
        non-zero the A..H range).

        :param ss: secondary-structure string (letters H/B/E/G/I/T/S, anything
                   else falls into the catch-all letter)
        :param pattern: string of digits, at least as long as ``ss``
        :return: the enhanced sequence, same length as ``ss``
        """
        # Translation tables replace the original 16-branch if/elif chain.
        buried = {'H': 'I', 'B': 'J', 'E': 'K', 'G': 'L',
                  'I': 'M', 'T': 'N', 'S': 'O'}          # sasa == 0
        exposed = {'H': 'A', 'B': 'B', 'E': 'C', 'G': 'D',
                   'I': 'E', 'T': 'F', 'S': 'G'}         # sasa != 0
        sequence = ""
        for index, letter in enumerate(ss):
            if int(pattern[index]) == 0:
                sequence += buried.get(letter, 'P')
            else:
                sequence += exposed.get(letter, 'H')
        return sequence

    def search_by_common_pattern(self, pattern):
        """Search the proteins table for rows whose ``enhanced`` column matches
        ``pattern``: exact substring position by default, or n-gram similarity
        when ``--fuzzy`` is set. Prints matches and returns the raw rows.
        """
        if not self.db:
            print("Couldn't connect to the database.")
            self.p.print_help()
            return
        # NOTE(review): `pattern` is interpolated straight into the SQL text.
        # If it can come from untrusted input this is an injection risk —
        # confirm whether the driver supports parameterized queries.
        if self.args.fuzzy:
            query = "select p.* , ngramSearch(enhanced,'{0}') as ngram from proteins as p where ngram > {2} order by ngram DESC limit {3}".format(pattern,pattern,self.args.fuzzylevel,self.args.limit
                                                                                                                                                   )
        else:
            query = "select p.* , position(enhanced,'{0}') as pos from proteins as p where position(enhanced,'{1}') > 0 limit {2}".format(
            pattern, pattern,self.args.limit)
        rows = []
        # The extra result column depends on which query flavour ran.
        headers = ["protein_id"]
        if not self.args.fuzzy:
            headers += ["pos"]
        else:
            headers += ["ngram"]
        for row in self.db.select(query):
            rows.append(row)
        if len(rows) > 0:
            self.print_results(headers, rows)
        return rows

    def print_results(self, headers, rows):
        """Render ``rows`` as an ASCII table with the given column headers,
        followed by a total count.

        :param headers: attribute names to show, in column order
        :param rows: result row objects; missing attributes are skipped
        """
        # Cap the displayed rows at the user-requested limit. The original
        # computed this value but never used it; the search queries already
        # apply the same LIMIT, so truncating here is normally a no-op.
        limit = min(self.args.limit, len(rows))
        data = [headers] + [[getattr(row, x) for x in headers if hasattr(row, x)]
                            for row in rows[:limit]]
        table = AsciiTable(data)
        table.title = "Possible Matche(s)"
        table.inner_row_border = True
        print(table.table)
        print('Total : {0}'.format(len(data) - 1))
# Example #31
class DecimalFieldsTest(unittest.TestCase):
    """Tests for the Decimal / Decimal32 / Decimal64 / Decimal128 field types."""

    def setUp(self):
        self.database = Database('test-db', log_statements=True)
        try:
            self.database.create_table(DecimalModel)
        except ServerError as e:
            # This ClickHouse version does not support decimals yet.
            # NOTE(review): relies on ServerError exposing a `message`
            # attribute (ORM-specific; plain exceptions have none in Py3).
            raise unittest.SkipTest(e.message)

    def tearDown(self):
        self.database.drop_database()

    def _insert_sample_data(self):
        """Insert five rows: defaults only, then one value per decimal width."""
        self.database.insert([
            DecimalModel(date_field='2016-08-20'),
            DecimalModel(date_field='2016-08-21', dec=Decimal('1.234')),
            DecimalModel(date_field='2016-08-22', dec32=Decimal('12342.2345')),
            DecimalModel(date_field='2016-08-23', dec64=Decimal('12342.23456')),
            DecimalModel(date_field='2016-08-24', dec128=Decimal('-4545456612342.234567')),
        ])

    def _assert_sample_data(self, results):
        """Check the rows from _insert_sample_data, assumed ordered by date."""
        self.assertEqual(len(results), 5)
        self.assertEqual(results[0].dec, Decimal(0))
        # presumably the field's declared default — DecimalModel not shown here
        self.assertEqual(results[0].dec32, Decimal(17))
        self.assertEqual(results[1].dec, Decimal('1.234'))
        self.assertEqual(results[2].dec32, Decimal('12342.2345'))
        self.assertEqual(results[3].dec64, Decimal('12342.23456'))
        self.assertEqual(results[4].dec128, Decimal('-4545456612342.234567'))

    def test_insert_and_select(self):
        self._insert_sample_data()
        query = 'SELECT * from $table ORDER BY date_field'
        results = list(self.database.select(query, DecimalModel))
        self._assert_sample_data(results)

    def test_ad_hoc_model(self):
        # Same data, but selected without a model class (ad-hoc model path).
        self._insert_sample_data()
        query = 'SELECT * from decimalmodel ORDER BY date_field'
        results = list(self.database.select(query))
        self._assert_sample_data(results)

    def test_rounding(self):
        # Values are rounded to the field's scale on assignment.
        d = Decimal('11111.2340000000000000001')
        self.database.insert([DecimalModel(date_field='2016-08-20', dec=d, dec32=d, dec64=d, dec128=d)])
        m = DecimalModel.objects_in(self.database)[0]
        for val in (m.dec, m.dec32, m.dec64, m.dec128):
            self.assertEqual(val, Decimal('11111.234'))

    def test_assignment_ok(self):
        # Anything convertible to Decimal must be accepted without raising.
        for value in (True, False, 17, 3.14, '20.5', Decimal('20.5')):
            DecimalModel(dec=value)

    def test_assignment_error(self):
        # Unconvertible or non-finite values must be rejected.
        for value in ('abc', u'זה ארוך', None, float('NaN'), Decimal('-Infinity')):
            with self.assertRaises(ValueError):
                DecimalModel(dec=value)

    def test_aggregation(self):
        self._insert_sample_data()
        result = DecimalModel.objects_in(self.database).aggregate(m='min(dec)', n='max(dec)')
        self.assertEqual(result[0].m, Decimal(0))
        self.assertEqual(result[0].n, Decimal('1.234'))

    def test_precision_and_scale(self):
        # Go over all valid combinations — constructing must not raise.
        # (The original bound each field to an unused local `f`.)
        for precision in range(1, 39):
            for scale in range(0, precision + 1):
                DecimalField(precision, scale)
        # Some invalid combinations
        for precision, scale in [(0, 0), (-1, 7), (7, -1), (39, 5), (20, 21)]:
            with self.assertRaises(AssertionError):
                DecimalField(precision, scale)

    def test_min_max(self):
        # In range
        f = DecimalField(3, 1)
        f.validate(f.to_python('99.9', None))
        f.validate(f.to_python('-99.9', None))
        # In range after rounding
        f.validate(f.to_python('99.94', None))
        f.validate(f.to_python('-99.94', None))
        # Out of range
        with self.assertRaises(ValueError):
            f.validate(f.to_python('99.99', None))
        with self.assertRaises(ValueError):
            f.validate(f.to_python('-99.99', None))
        # In range
        f = Decimal32Field(4)
        f.validate(f.to_python('99999.9999', None))
        f.validate(f.to_python('-99999.9999', None))
        # In range after rounding
        f.validate(f.to_python('99999.99994', None))
        f.validate(f.to_python('-99999.99994', None))
        # Out of range
        with self.assertRaises(ValueError):
            f.validate(f.to_python('100000', None))
        with self.assertRaises(ValueError):
            f.validate(f.to_python('-100000', None))
class EnginesTestCase(unittest.TestCase):
    """Tests covering the table engines supported by the ORM."""

    def setUp(self):
        # Fresh database handle per test; dropped again in tearDown.
        self.database = Database('test-db')

    def tearDown(self):
        self.database.drop_database()

    def _create_and_insert(self, model_class):
        """Create the table for ``model_class`` and insert one sample row."""
        self.database.create_table(model_class)
        sample = model_class(
            date='2017-01-01',
            event_id=23423,
            event_group=13,
            event_count=7,
            event_version=1,
        )
        self.database.insert([sample])

    def test_merge_tree(self):
        # Plain MergeTree: table can be created and accepts an insert.
        class TestModel(SampleModel):
            engine = MergeTree('date', ('date', 'event_id', 'event_group'))

        self._create_and_insert(TestModel)

    def test_merge_tree_with_sampling(self):
        # MergeTree with a SAMPLE BY expression.
        class TestModel(SampleModel):
            engine = MergeTree('date', ('date', 'event_id', 'event_group'),
                               sampling_expr='intHash32(event_id)')

        self._create_and_insert(TestModel)

    def test_merge_tree_with_granularity(self):
        # MergeTree with a non-default index granularity.
        class TestModel(SampleModel):
            engine = MergeTree('date', ('date', 'event_id', 'event_group'),
                               index_granularity=4096)

        self._create_and_insert(TestModel)

    def test_replicated_merge_tree(self):
        """Supplying replica args turns the engine SQL into ReplicatedMergeTree."""
        engine = MergeTree(
            'date', ('date', 'event_id', 'event_group'),
            replica_table_path='/clickhouse/tables/{layer}-{shard}/hits',
            replica_name='{replica}')
        expected = "ReplicatedMergeTree('/clickhouse/tables/{layer}-{shard}/hits', '{replica}', date, (date, event_id, event_group), 8192)"
        # assertEquals is a deprecated alias (removed in Python 3.12).
        self.assertEqual(engine.create_table_sql(), expected)

    def test_collapsing_merge_tree(self):
        # CollapsingMergeTree keyed by the sign/version column.
        class TestModel(SampleModel):
            engine = CollapsingMergeTree('date',
                                         ('date', 'event_id', 'event_group'),
                                         'event_version')

        self._create_and_insert(TestModel)

    def test_summing_merge_tree(self):
        # SummingMergeTree summing event_count over the key columns.
        class TestModel(SampleModel):
            engine = SummingMergeTree('date', ('date', 'event_group'),
                                      ('event_count', ))

        self._create_and_insert(TestModel)

    def test_replacing_merge_tree(self):
        # ReplacingMergeTree using event_uversion as the version column
        # (presumably a materialized field on SampleModel — not shown here).
        class TestModel(SampleModel):
            engine = ReplacingMergeTree('date',
                                        ('date', 'event_id', 'event_group'),
                                        'event_uversion')

        self._create_and_insert(TestModel)

    def test_tiny_log(self):
        # TinyLog engine takes no parameters.
        class TestModel(SampleModel):
            engine = TinyLog()

        self._create_and_insert(TestModel)

    def test_log(self):
        # Log engine takes no parameters.
        class TestModel(SampleModel):
            engine = Log()

        self._create_and_insert(TestModel)

    def test_memory(self):
        # In-memory engine takes no parameters.
        class TestModel(SampleModel):
            engine = Memory()

        self._create_and_insert(TestModel)

    def test_merge(self):
        """Merge engine: read-only union over tables matching a name regex."""
        class TestModel1(SampleModel):
            engine = TinyLog()

        class TestModel2(SampleModel):
            engine = TinyLog()

        class TestMergeModel(MergeModel, SampleModel):
            # Matches both underlying tables by name prefix.
            engine = Merge('^testmodel')

        self.database.create_table(TestModel1)
        self.database.create_table(TestModel2)
        self.database.create_table(TestMergeModel)

        # Insert operations are restricted for this model type
        with self.assertRaises(DatabaseException):
            self.database.insert([
                TestMergeModel(date='2017-01-01',
                               event_id=23423,
                               event_group=13,
                               event_count=7,
                               event_version=1)
            ])

        # Testing select: one row per underlying table.
        self.database.insert([
            TestModel1(date='2017-01-01',
                       event_id=1,
                       event_group=1,
                       event_count=1,
                       event_version=1)
        ])
        self.database.insert([
            TestModel2(date='2017-01-02',
                       event_id=2,
                       event_group=2,
                       event_count=2,
                       event_version=2)
        ])
        # event_uversion is materialized field. So * won't select it and it will be zero
        res = self.database.select(
            'SELECT *, event_uversion FROM $table ORDER BY event_id',
            model_class=TestMergeModel)
        res = [row for row in res]
        self.assertEqual(2, len(res))
        # _table identifies which underlying table each row came from.
        self.assertDictEqual(
            {
                '_table': 'testmodel1',
                'date': datetime.date(2017, 1, 1),
                'event_id': 1,
                'event_group': 1,
                'event_count': 1,
                'event_version': 1,
                'event_uversion': 1
            }, res[0].to_dict(include_readonly=True))
        self.assertDictEqual(
            {
                '_table': 'testmodel2',
                'date': datetime.date(2017, 1, 2),
                'event_id': 2,
                'event_group': 2,
                'event_count': 2,
                'event_version': 2,
                'event_uversion': 2
            }, res[1].to_dict(include_readonly=True))