def create_temp_table(self, params: dict):
    """
    This method creates the table that will store the regression parameters
    (i.e., weights and bias).

    :param params: the sdca regressor params
    :return: list of sql statements [table creation, weight insertion, index creation]
    """

    params = SDCARegressorSQL.check_params(params)

    # drop and re-create the weight table: one row per weight, keyed by the weight index
    create_stm = f'DROP TABLE IF EXISTS {self.temp_table_name};\n'
    create_stm += f'CREATE TABLE {self.temp_table_name} ({self.temp_table_pk} int, {self.temp_table_weight_col} float);\n'

    # BUG FIX: the original opened the row list with 'VALUES (' and closed it with ')',
    # which is not valid multi-row INSERT syntax and did not match the SQL Server
    # batching branch below; use the same 'VALUES\n ... ;' form as the logistic
    # regression variant of this method.
    insert_stm = f'INSERT INTO {self.temp_table_name} VALUES\n'
    for widx, w in enumerate(params['weights']):
        if self.dbms == 'sqlserver':
            # SQL Server limits the number of value rows per INSERT statement:
            # close the current statement and open a new one every 1000 rows
            if widx > 0 and widx % 1000 == 0:
                insert_stm = insert_stm[:-2]  # remove ',\n'
                insert_stm += ';\n\n'
                insert_stm += f'INSERT INTO {self.temp_table_name} VALUES\n'
        insert_stm += f'({widx}, {w}),\n'
    # BUG FIX: the original used insert_stm[-2:], which kept only the final two
    # characters of the statement instead of stripping the trailing ',\n'
    insert_stm = insert_stm[:-2]  # remove ',\n'
    insert_stm += ';\n'

    # index the weight table on its primary-key column to speed up later joins
    index_name = f'{self.temp_table_name}_{self.temp_table_pk}'
    index_stm = DBMSUtils.create_index(self.dbms, index_name, self.temp_table_name, self.temp_table_pk)

    return [create_stm, insert_stm, index_stm]
def create_temp_table(self, params: dict):
    """
    Build the SQL statements that materialize the logistic regression
    parameters (i.e., weights and bias), one table per output class.

    :param params: the logistic regression params
    :return: list of sql statements (table creation, weight insertion and
             index creation for each class table)
    """

    params = LogisticRegressionSQL.check_params(params)

    statements = []
    # one weight table per class, named <temp_table_name>_<class_idx>
    for class_idx, class_weights in enumerate(params['weights']):
        class_table = f'{self.temp_table_name}_{class_idx}'

        # drop and re-create the table holding (feature index, weight) pairs
        create_stm = (
            f'DROP TABLE IF EXISTS {class_table};\n'
            f'CREATE TABLE {class_table} ({self.temp_table_pk} int, {self.temp_table_weight_col} float);\n'
        )

        # multi-row INSERT of the class weights, keyed by feature index
        insert_stm = f'INSERT INTO {class_table} VALUES\n'
        for feature_idx, weight in enumerate(class_weights):
            # SQL Server caps the number of value rows in a single INSERT:
            # close the statement and open a new one every 1000 rows
            if self.dbms == 'sqlserver' and feature_idx > 0 and feature_idx % 1000 == 0:
                insert_stm = insert_stm[:-2] + ';\n\n'  # remove trailing ',\n'
                insert_stm += f'INSERT INTO {class_table} VALUES\n'
            insert_stm += f'({feature_idx}, {weight}),\n'
        insert_stm = insert_stm[:-2] + ';\n'  # remove trailing ',\n'

        # index the class table on its primary-key column
        index_stm = DBMSUtils.create_index(
            self.dbms, f'{class_table}_{self.temp_table_pk}', class_table, self.temp_table_pk
        )

        statements.extend([create_stm, insert_stm, index_stm])

    return statements
def _get_query_sparse_ohe(self, ohe_params: dict):
    """
    This method creates an SQL query that implements a sparse ohe transformation. The query is composed by a main
    CASE statement that replicates the OHE mapping. For high dimensional data, it is not possible to encode the
    mapping inside a single CASE statement, because of the limit of the number of WHEN statements that can be
    inserted in a query. In this case multiple queries are generated and for each of them the maximum number of WHEN
    statements allowed is considered. Each query result is saved into a temporary table.

    :param ohe_params: dictionary containing the parameters extracted from the fitted OneHotEncoder
    :return: a pair: ([create-table statement, insert query, index statements],
             select query reading the populated ohe table)
    """

    ohe_params = OneHotEncoderSQL.check_ohe_params(ohe_params)
    # mapping: original column -> {category value -> encoded feature index}
    ohe_feature_map = ohe_params['ohe2idx_map']
    original_ohe_features = ohe_params['ohe_features']

    # accumulates the INSERT statements that populate the intermediate ohe table
    ohe_query = ""

    # considering that each feature after the ohe is used to create a WHEN statement, it is needed to check if this
    # number if greater (or not) than the maximum number of WHEN statements that can be included in a single CASE
    # statement. For SQLSERVER the maximum number of WHEN statements is 9700.
    # https://www.sqlservercentral.com/forums/topic/maximum-number-of-when-then-lines-in-a-case-statement
    sql_max_when_statements = 9700
    # sql_max_when_statements = 100

    # if the OHE is applied directly on the original table then a temporary table is created to store OHE
    # results in a triplet data format
    warn_message = "A temporary table 'ohe_table' will be created."
    logging.warning(warn_message)

    # add to the ohe query the SQL statement for the creation of the intermediate ohe table
    # (triplet format: row pk, feature value, feature index)
    create_ohe_table_query = f"DROP TABLE IF EXISTS {self.ohe_table_name};\n"
    create_ohe_table_query += f"CREATE TABLE {self.ohe_table_name}({self.ohe_table_pk} int, "
    create_ohe_table_query += f"{self.ohe_table_fval_col} float, {self.ohe_table_fidx_col} int);\n\n"
    # create_ohe_table_query += f" PRIMARY KEY({self.ohe_table_pk}, {self.ohe_table_fidx_col}));\n\n"
    # ohe_query += create_ohe_table_query

    # split, if needed, the OHEed features in batches smaller than the SQL limits
    ohe_feature_map_batches = []
    num_batch = 1
    # loop over OHEed columns
    for col in ohe_feature_map:
        feature_map = ohe_feature_map[col]
        num_ohe_features_per_col = len(feature_map)

        # check if the number of features derived from the current OHEed column is greater than the DBMS limits
        if num_ohe_features_per_col > sql_max_when_statements:
            # split the query in multiple batch queries
            batch_size = sql_max_when_statements
            # ceil division: number of batches needed to cover all the features
            if num_ohe_features_per_col % batch_size == 0:
                num_batch = num_ohe_features_per_col // batch_size
            else:
                num_batch = num_ohe_features_per_col // batch_size + 1

            feature_map_vals = list(feature_map.items())
            # loop over the number of batch
            for i in range(num_batch):
                # select a partition of the features after ohe
                batch_ohe_feature_map = dict(
                    feature_map_vals[i * batch_size:i * batch_size + batch_size])
                ohe_feature_map_batches.append(
                    {col: batch_ohe_feature_map})
        else:
            ohe_feature_map_batches.append({col: feature_map})
    # NOTE(review): num_batch is never reset inside the loop, so after the loop it
    # reflects the LAST column that exceeded the limit (or stays 1 if none did).
    # If any earlier column was split, batch_mode below is enabled for every batch
    # — confirm this carry-over is intentional.

    # loop over the batches
    ohe_sub_queries = []
    for ohe_feature_map_batch in ohe_feature_map_batches:
        # create the SQL query that applies the One Hot Encoding on the selected features
        batch_mode = False
        if num_batch > 1:
            batch_mode = True
        ohe_batch_query = self._create_ohe_query(ohe_feature_map_batch, original_ohe_features,
                                                 batch_mode=batch_mode)
        ohe_sub_queries.append(ohe_batch_query)

    # optimization: combine multiple ohe batch queries to reduce the total number of INSERT statements
    cum_sum = 0
    current_combined_suq_queries = []
    for j in range(len(ohe_feature_map_batches)):
        suq_query = ohe_sub_queries[j]
        ohe_feature_map_batch = ohe_feature_map_batches[j]
        # size of this batch = number of WHEN statements it contributes
        ohe_batch_query_size = len(list(ohe_feature_map_batch.values())[0])
        cum_sum += ohe_batch_query_size
        current_combined_suq_queries.append(suq_query)

        # once the running size exceeds the limit, flush all accumulated sub-queries
        # except the last one (which starts the next combined group)
        if cum_sum > sql_max_when_statements:
            cum_sum = ohe_batch_query_size
            list_joint_sub_queries = current_combined_suq_queries[:-1]
            current_combined_suq_queries = [
                current_combined_suq_queries[-1]
            ]

            # glue the flushed sub-queries together with UNION ALL
            joint_sub_queries = ""
            for sub_query in list_joint_sub_queries:
                joint_sub_queries += "{}\n\n UNION ALL \n\n".format(
                    sub_query[:-3])  # remove ';\n\n'
            joint_sub_queries = joint_sub_queries[:-15] + ";"  # remove '\n\n UNION ALL \n\n'

            # if multiple batch queries are generated, they have to be saved in a temporary table with an
            # INSERT statement
            # NOTE(review): both branches below are identical — the num_batch check
            # has no effect here; confirm whether a different statement was intended
            # in one of the branches.
            insert_stm = ""
            if num_batch > 1:
                insert_stm += f"INSERT INTO {self.ohe_table_name}\n"
            else:
                insert_stm += f"INSERT INTO {self.ohe_table_name}\n"

            ohe_query += "{}{}\n\n".format(insert_stm, joint_sub_queries)

    # combine the last ohe sub queries (whatever remained un-flushed)
    joint_sub_queries = ""
    for sub_query in current_combined_suq_queries:
        joint_sub_queries += "{}\n\n UNION ALL \n\n".format(
            sub_query[:-3])  # remove ';\n\n'
    joint_sub_queries = joint_sub_queries[:-15] + ";"  # remove '\n\n UNION ALL \n\n'

    # if multiple batch queries are generated, they have to be saved in a temporary table with an
    # INSERT statement
    # NOTE(review): identical branches again — see the note above in this method.
    insert_stm = ""
    if num_batch > 1:
        insert_stm += f"INSERT INTO {self.ohe_table_name}\n"
    else:
        insert_stm += f"INSERT INTO {self.ohe_table_name}\n"

    ohe_query += "{}{}\n\n".format(insert_stm, joint_sub_queries)

    # create an index on the ohe table (on the pk and on the feature-index column)
    index_ohe = DBMSUtils.create_index(
        dbms=self.dbms,
        index_name=f'{self.ohe_table_name}_{self.ohe_table_pk}',
        target_table=self.ohe_table_name,
        target_col=self.ohe_table_pk)
    index_ohe += DBMSUtils.create_index(
        dbms=self.dbms,
        index_name=f'{self.ohe_table_name}_{self.ohe_table_fidx_col}',
        target_table=self.ohe_table_name,
        target_col=self.ohe_table_fidx_col)
    # ohe_query += index_ohe

    return [create_ohe_table_query, ohe_query, index_ohe], f'select * from {self.ohe_table_name}'