示例#1
0
 def test_optimize_always_false_list_conditions(self):
     """`column IN (1, 2) AND column = 3` should optimize to an always-false result."""
     in_list = Condition(Column('column', 'tbl'), ShardingOperator.IN,
                         SQLNumberExpression(1), SQLNumberExpression(2))
     equals_three = Condition(Column('column', 'tbl'), ShardingOperator.EQUAL,
                              SQLNumberExpression(3))

     # Wrap both conditions in a single AND group inside a single OR group.
     conjunction = AndCondition()
     conjunction.conditions.extend((in_list, equals_three))
     disjunction = OrCondition()
     disjunction.and_conditions.append(conjunction)

     optimized = QueryOptimizeEngine(disjunction, []).optimize()
     self.assertTrue(optimized.is_always_false())
示例#2
0
 def test_optimize_list_conditions(self):
     """`column IN (1, 2) AND column = 1` should optimize to one list value `[1]`."""
     in_list = Condition(Column('column', 'tbl'), ShardingOperator.IN,
                         SQLNumberExpression(1), SQLNumberExpression(2))
     equals_one = Condition(Column('column', 'tbl'), ShardingOperator.EQUAL,
                            SQLNumberExpression(1))

     # Wrap both conditions in a single AND group inside a single OR group.
     conjunction = AndCondition()
     conjunction.conditions.extend((in_list, equals_one))
     disjunction = OrCondition()
     disjunction.and_conditions.append(conjunction)

     optimized = QueryOptimizeEngine(disjunction, []).optimize()
     self.assertFalse(optimized.is_always_false())

     conditions = optimized.sharding_conditions
     self.assertEqual(len(conditions), 1)
     values = conditions[0].sharding_values
     self.assertEqual(len(values), 1)

     only_value = values[0]
     self.assertIsInstance(only_value, ListShardingValue)
     self.assertEqual(only_value.values, [1])
示例#3
0
 def _remove_generate_key_column(self, insert_statement, values_count):
     """Drop the generated key column when fewer values than columns were parsed.

     When the sharding rule returns a generated key column for the statement's
     single table and ``values_count`` is smaller than the number of parsed
     columns, the column is removed from ``insert_statement.columns`` and from
     every items token that lists it. ``insert_statement.generate_key_column_index``
     is then reset to -1.

     :param insert_statement: parsed INSERT statement; mutated in place.
     :param values_count: number of values parsed for the row.
     :raises ValueError: if the generated key column is not present in
         ``insert_statement.columns`` (``list.remove`` semantics, unchanged).
     """
     table_name = insert_statement.tables.get_single_table_name()
     generate_key_column = self.sharding_rule.get_generate_key_column(table_name)
     if not generate_key_column or values_count >= len(insert_statement.columns):
         return

     insert_statement.columns.remove(Column(generate_key_column.name, table_name))
     for each in insert_statement.get_items_tokens():
         # A token that does not list the column is left as-is rather than
         # letting list.remove raise ValueError for it.
         if generate_key_column.name in each.items:
             each.items.remove(generate_key_column.name)

     # The column is gone from the column list, so any recorded index is stale.
     # Reset it exactly once, including when there are no items tokens.
     insert_statement.generate_key_column_index = -1
示例#4
0
 def parse(self, insert_statement, sharding_meta_data):
     """Parse the INSERT column list and record it on ``insert_statement``.

     Two paths are handled:

     * An explicit ``(col, ...)`` list: column names are read from the lexer
       until ``)`` or ``Assist.END`` is reached. ``columns_list_last_position``
       is then set from the closing token.
     * No explicit list: every column name returned by ``sharding_meta_data``
       for the table is used. An opening ``InsertColumnToken``, an ``ItemsToken``
       with those names and a closing ``InsertColumnToken`` are appended to
       ``insert_statement.sql_tokens`` at ``begin_position``.

     In both paths the ordinal of the generated key column (compared with
     ``strutil.equals_ignore_case``) is stored in
     ``insert_statement.generate_key_column_index``. The collected ``Column``
     objects are appended to ``insert_statement.columns``.
     """
     lexer = self.lexer_engine
     table_name = insert_statement.tables.get_single_table_name()
     key_column = self.sharding_rule.get_generate_key_column(table_name)
     columns = []

     def collect(name):
         # Record one column; its ordinal is its position in `columns`.
         columns.append(Column(name, table_name))
         if key_column and strutil.equals_ignore_case(key_column.name, name):
             insert_statement.generate_key_column_index = len(columns) - 1

     if lexer.equal_any(Symbol.LEFT_PAREN):
         while True:
             lexer.next_token()
             name = sqlutil.get_exactly_value(lexer.get_current_token().literals)
             lexer.next_token()
             collect(name)
             if lexer.equal_any(Symbol.RIGHT_PAREN) or lexer.equal_any(Assist.END):
                 break
         closing = lexer.get_current_token()
         insert_statement.columns_list_last_position = closing.end_position - len(closing.literals)
         lexer.next_token()
     else:
         table_meta = sharding_meta_data.table_meta_data_map.get(table_name)
         column_names = table_meta.get_all_column_names()
         current = lexer.get_current_token()
         begin_position = current.end_position - len(current.literals) - 1
         insert_statement.sql_tokens.append(InsertColumnToken(begin_position, '('))
         columns_token = ItemsToken(begin_position)
         columns_token.is_first_of_items_special = True
         for name in column_names:
             collect(name)
             columns_token.items.append(name)
         insert_statement.sql_tokens.append(columns_token)
         insert_statement.sql_tokens.append(InsertColumnToken(begin_position, ')'))
         insert_statement.columns_list_last_position = begin_position
     insert_statement.columns.extend(columns)
示例#5
0
    def parse(self, insert_statement):
        """Parse an ``ON DUPLICATE KEY UPDATE`` clause, if one follows.

        Returns immediately unless the lexer can skip one of the keywords from
        ``self.get_customized_insert_keywords()``. Otherwise ``DUPLICATE``,
        ``KEY`` and ``UPDATE`` are accepted in order. Each comma-separated
        assignment target is then checked. Targeting a sharding column raises
        ``SQLParsingException``.
        """
        lexer = self.lexer_engine
        if not lexer.skip_if_equal(*self.get_customized_insert_keywords()):
            return

        for keyword in (DefaultKeyword.DUPLICATE, DefaultKeyword.KEY, DefaultKeyword.UPDATE):
            lexer.accept(keyword)

        has_more = True
        while has_more:
            target = Column(lexer.get_current_token().literals,
                            insert_statement.tables.get_single_table_name())
            if self.sharding_rule.is_sharding_column(target):
                raise SQLParsingException(
                    'INSERT INTO .... ON DUPLICATE KEY UPDATE can not support on sharding column: {}'.format(
                        target.name))
            # Skip the remainder of this assignment up to a comma or end of input.
            lexer.skip_until(Symbol.COMMA, Assist.END)
            has_more = lexer.skip_if_equal(Symbol.COMMA)
示例#6
0
 def _get_column_without_owner(self, tables, identifier_expression):
     if tables.is_single_table():
         return Column(sqlutil.get_exactly_value(identifier_expression.name), tables.get_single_table_name())
示例#7
0
 def _get_column_with_owner(self, tables, property_expression):
     """Build a ``Column`` for an owner-qualified identifier (``owner.name``).

     The owner is looked up with ``tables.find``. When no table is found,
     ``None`` is returned.
     """
     owner_name = sqlutil.get_exactly_value(property_expression.owner.name)
     table = tables.find(owner_name)
     if not table:
         return None
     return Column(sqlutil.get_exactly_value(property_expression.name), table.name)
示例#8
0
 def get_generate_key_column(self, logic_table_name):
     """Return the generated key ``Column`` configured for *logic_table_name*.

     The first entry of ``self.table_rules`` that meets both conditions wins:

     * its ``logic_table`` matches *logic_table_name* under
       ``strutil.equals_ignore_case``;
     * it has a truthy ``generate_key_column``.

     Returns ``None`` when no rule qualifies.
     """
     matching_rule = next(
         (rule for rule in self.table_rules
          if strutil.equals_ignore_case(logic_table_name, rule.logic_table)
          and rule.generate_key_column),
         None)
     if matching_rule is None:
         return None
     return Column(matching_rule.generate_key_column, logic_table_name)