示例#1
0
 def execute_query(cls, db_root, query, database_schemas, silent=False):
     """Execute the recovered SQL of a sampled query against its SQLite database.

     Side effect: fills ``query['query_toks']`` and ``query['query_toks_no_value']``
     in place before executing.

     Args:
         db_root: directory containing ``<db_id>/<db_id>.sqlite`` files.
         query: example dict; reads ``'db_id'``, ``'recov'`` and ``'column_mapped'``.
         database_schemas: mapping db_id -> schema object passed to
             ``editsql_postprocess.postprocess_one``.
         silent: forwarded to ``timed_execute``.

     Returns:
         ``None`` when the recovered SQL contains ``'t5'`` (the query is skipped
         and ``query`` is left unmodified); otherwise the result of
         ``timed_execute`` for the recovered SQL.
     """
     db_id = query['db_id']
     db_path = os.path.join(db_root, db_id, db_id + '.sqlite')
     query_recov = query['recov']
     # NOTE(review): presumably filters queries using a 't5' table alias; confirm intent.
     if 't5' in query_recov:
         return None
     query['query_toks'] = query_recov.replace('.', ' . ').split()

     # Re-postprocess the column-mapped tokens and turn the emitted placeholder
     # literals back into 'limit_value' / 'value' tokens.
     # NOTE(review): the plain ' 1' replacement also rewrites the start of longer
     # numbers (e.g. ' 10' -> ' value0'); confirm postprocess_one never emits those here.
     no_value_sql = editsql_postprocess.postprocess_one(
         ' '.join(query['column_mapped']), database_schemas[db_id])
     no_value_sql = (no_value_sql.replace('limit 1', 'limit_value')
                     .replace(' 1', ' value')
                     .replace('.', ' . '))
     query['query_toks_no_value'] = no_value_sql.split(' ')

     # The DB schema was previously loaded here into an unused local; dropped to
     # avoid an extra SQLite read per query.
     return timed_execute(db_path, query_recov, timeout=1, sleep=0.001, silent=silent)
示例#2
0
    def compute_upperbound(self, dev):
        """Score the gold SQL pointer sequences of ``dev`` as an upper bound.

        Each example's gold ``pointer_query`` is decoded into SQL tokens with
        ``preprocess.SQLDataset.recover_query``, postprocessed against the
        example's database schema, and stored under the example id. The model
        is then put in eval mode and the predictions are scored.

        Returns:
            Whatever ``self.compute_metrics(dev, preds)`` returns.
        """
        def _gold_prediction(example):
            # Decode the gold pointers and turn the tokens into a SQL string.
            gold_pointer = example['pointer_query']
            sql_toks = preprocess.SQLDataset.recover_query(
                gold_pointer, example['cands_query'], voc=self.sql_vocab)
            sql = editsql_postprocess.postprocess_one(
                ' '.join(sql_toks), self.database_schemas[example['db_id']])
            return {'query_pointer': gold_pointer, 'query_toks': sql_toks, 'query': sql}

        preds = {}
        for example in dev:
            prediction = _gold_prediction(example)
            preds[example['id']] = prediction
        self.eval()
        return self.compute_metrics(dev, preds)
示例#3
0
    def compute_upperbound(self, dev):
        """Score the gold pointer sequences of ``dev`` as an upper bound.

        For every example:

        - the gold SQL/value pointers are decoded with
          ``preprocess.SQLDataset.recover_query``;
        - the SQL tokens are postprocessed against the example's schema;
        - value tokens are grouped on ``'SEP'``, detokenized, and passed
          through ``preprocess.value_replace``;
        - the gold question pointers are decoded with ``recover_question``.

        When ``self.args.keep_values`` is set, the values are substituted into
        the SQL via ``self.postprocess_value``. The model is then put in eval
        mode and the predictions are scored.

        Returns:
            Whatever ``self.compute_metrics(dev, preds)`` returns.
        """
        preds = {}
        for ex in dev:
            gold_pointer = ex['pointer_query']
            sql_toks, value_toks = preprocess.SQLDataset.recover_query(
                gold_pointer, ex['cands_query'],
                ex['pointer_value'], ex['cands_value'],
                voc=self.sql_vocab)
            sql = editsql_postprocess.postprocess_one(
                ' '.join(sql_toks), self.database_schemas[ex['db_id']])

            # Join with a '___' marker so the 'SEP' token can split the value
            # tokens into groups. Each non-empty group is detokenized, then
            # word-level substitutions from preprocess.value_replace apply.
            values = []
            for group in '___'.join(value_toks).split('SEP'):
                if not group:
                    continue
                text = self.bert_tokenizer.convert_tokens_to_string(
                    group.split('___'))
                values.append(' '.join(
                    preprocess.value_replace.get(word, word)
                    for word in text.split()))

            if self.args.keep_values:
                sql = self.postprocess_value(sql, values)

            utt_toks = preprocess.SQLDataset.recover_question(
                ex['pointer_question'], ex['cands_question'],
                voc=self.utt_vocab)

            preds[ex['id']] = {
                'query_pointer': gold_pointer,
                'query_toks': sql_toks,
                'query': sql,
                'utt_toks': utt_toks,
                'utt': ' '.join(utt_toks),
                'value_toks': value_toks,
                'values': values,
            }
        self.eval()
        return self.compute_metrics(dev, preds)
示例#4
0
    def extract_preds(self, out, feat, batch):
        """Decode model outputs into one prediction dict per batch example.

        Pointer sequences are taken as the argmax over dim 2 of
        ``out['query_dec']``, ``out['value_dec']`` and ``out['utt_dec']``. They
        are paired with the examples in ``batch`` and decoded into SQL tokens,
        value strings and the utterance. ``feat`` is not used.

        Returns:
            dict mapping example id to its prediction. If SQL postprocessing or
            value substitution raises, ``query`` holds ``repr`` of the exception.
        """
        def _argmax_ids(key):
            return out[key].max(2)[1].tolist()

        query_ptrs = _argmax_ids('query_dec')
        value_ptrs = _argmax_ids('value_dec')
        utt_ptrs = _argmax_ids('utt_dec')

        preds = {}
        for pointer, value_pointer, utt_pointer, ex in zip(
                query_ptrs, value_ptrs, utt_ptrs, batch):
            sql_toks, value_toks = preprocess.SQLDataset.recover_query(
                pointer, ex['cands_query'], value_pointer, ex['cands_value'],
                voc=self.sql_vocab)
            # 'SEP' separates values; '___' keeps sub-tokens of one value together.
            values = [
                self.bert_tokenizer.convert_tokens_to_string(chunk.split('___'))
                for chunk in '___'.join(value_toks).split('SEP')
                if chunk
            ]
            schema = self.database_schemas[ex['db_id']]
            try:
                sql = editsql_postprocess.postprocess_one(' '.join(sql_toks), schema)
                if self.args.keep_values:
                    sql = self.postprocess_value(sql, values)
            except Exception as err:
                # Best effort: record the failure as the prediction, keep going.
                sql = repr(err)

            utt_toks = preprocess.SQLDataset.recover_question(
                utt_pointer, ex['cands_question'], voc=self.utt_vocab)
            utt_text = self.bert_tokenizer.convert_tokens_to_string(utt_toks)
            # NOTE(review): str.replace also strips ' id' from the start of longer
            # words (e.g. ' identity' -> 'entity'); confirm this is intended.
            utt_text = utt_text.replace(' id', '')

            preds[ex['id']] = {
                'query_pointer': pointer,
                'query_toks': sql_toks,
                'query': sql,
                'value_pointer': value_pointer,
                'value_toks': value_toks,
                'values': values,
                'utt_pointer': utt_pointer,
                'utt_toks': utt_toks,
                'utt': utt_text,
            }
        return preds
示例#5
0
    def extract_preds(self, out, feat, batch):
        """Decode argmax SQL pointers into one prediction dict per batch example.

        Pointers are the argmax over dim 2 of ``out['query_dec']``, paired with
        the examples in ``batch``. ``feat`` is not used.

        Returns:
            dict mapping example id to a prediction with ``query_pointer``,
            ``query_toks``, ``query_no_value`` and ``query``. The last two are
            identical here; on postprocessing failure both hold ``repr`` of the
            exception.
        """
        pointers = out['query_dec'].max(2)[1].tolist()
        preds = {}
        for pointer, ex in zip(pointers, batch):
            sql_toks = preprocess.SQLDataset.recover_query(
                pointer, ex['cands_query'], voc=self.sql_vocab)
            schema = self.database_schemas[ex['db_id']]
            try:
                sql = editsql_postprocess.postprocess_one(' '.join(sql_toks), schema)
            except Exception as err:
                # Best effort: record the failure as the prediction, keep going.
                sql = repr(err)
            preds[ex['id']] = {
                'query_pointer': pointer,
                'query_toks': sql_toks,
                'query_no_value': sql,
                'query': sql,
            }
        return preds
示例#6
0
    def sample_one(cls, db_root, stats, db_ids, schema_tokens, column_names,
                   database_schemas, proc_cols, supports):
        """Sample one synthetic SQL example from template statistics.

        The sampling proceeds in four steps:

        1. A template is drawn with ``np.random.choice`` using ``stats``.
        2. A database whose ``supports`` entry covers every ``*_col_*``
           placeholder of the template is drawn uniformly.
        3. Placeholders ``<coarse type>_col_<i>`` are bound to randomly
           shuffled columns of that database.
        4. ``value`` slots are filled with ``cls.sample_value``.

        Args:
            db_root: directory containing ``<db_id>/<db_id>.sqlite`` files.
            stats: mapping template string -> probability (used as ``p``).
            db_ids, schema_tokens, column_names: not used by this method.
            database_schemas: db_id -> schema passed to
                ``editsql_postprocess.postprocess_one``.
            proc_cols: db_id -> list of column dicts (``'key'``/``'type'`` read here).
            supports: db_id -> set of placeholder names the database can fill.

        Returns:
            dict with the template and its probability, support statistics,
            the postprocessed SQL (``'recov'``), aligned values, the mapped token
            lists, the normalized query tokens and the db_id.

        Raises:
            Exception: if no database supports every placeholder of the template.
        """
        # sample a template
        templates = sorted(stats)
        template_probs = [stats[k] for k in templates]
        template = np.random.choice(templates, p=template_probs)
        ex = dict(template=template, p_template=stats[template])

        # gather databases that can support every column placeholder
        vcols = {t for t in template.split() if '_col_' in t}
        supported_dbs = sorted(
            db_id for db_id, supported_cols in supports.items()
            if not vcols - supported_cols)
        if not supported_dbs:
            raise Exception(
                'Could not find valid database for {}'.format(vcols))
        ex['num_supported_db'] = len(supported_dbs)
        ex['p_supported_db'] = 1 / len(supported_dbs)

        # sample a database at random
        db_id = np.random.choice(supported_dbs)
        db_path = os.path.join(db_root, db_id, db_id + '.sqlite')
        cols = proc_cols[db_id]

        # make a random mapping of '<coarse type>_col_<i>' placeholders to columns
        cols_shuffled = cols[:]
        np.random.shuffle(cols_shuffled)
        col_map = {c['key']: c for c in cols}
        col_map['*'] = dict(name='*', table_name='', key='*', type='*')
        mapping = {}
        # Number placeholders per exact coarse type. A prefix scan over the
        # mapping keys would be quadratic and would miscount when one type name
        # is a prefix of another.
        type_counts = {}
        for c in cols_shuffled:
            t = cls.get_coarse_type(c['type'])
            idx = type_counts.get(t, 0)
            type_counts[t] = idx + 1
            mapping['{}_col_{}'.format(t, idx)] = c['key']

        # map column placeholders to concrete column keys
        coarse_column_mapped_query = [
            mapping.get(t, t).lower() for t in template.split()
        ]
        column_mapped_query = cls.detemplatize(col_map,
                                               coarse_column_mapped_query)

        # Map values. Remember the most recent column and, if it directly
        # follows '(', the token before the '(' (presumably an aggregate);
        # both guide sample_value.
        value_mapped_query = []
        last_col = None
        last_op = None
        query = ''
        for i, t in enumerate(column_mapped_query):
            if t in col_map:
                last_col = col_map[t]
                if i - 2 >= 0 and column_mapped_query[i - 1] == '(':
                    last_op = column_mapped_query[i - 2]
                else:
                    last_op = None
            if t == 'value':
                val, val_toks = cls.sample_value(last_col, last_op, db_path)
                query += ' ' + val
                value_mapped_query.extend(val_toks)
            elif t == 'limit_value':
                query += ' limit 1'
                value_mapped_query.extend(['limit', '1'])
            else:
                query += ' ' + t
                value_mapped_query.append(t)

        ex['recov'] = editsql_postprocess.postprocess_one(
            query, database_schemas[db_id])
        # Raw string: the pattern is unchanged, but the old literal relied on
        # invalid escape sequences. NOTE(review): the match includes both
        # surrounding whitespace characters and the '*', so ' t1.* ' is deleted
        # entirely and its neighbours are joined; confirm this is intended.
        ex['recov'] = re.sub(r'\s([0-9a-zA-Z_]+\.)\*\s', '', ex['recov'])
        ex['values'] = SQLDataset.align_values(column_mapped_query,
                                               value_mapped_query)

        ex['column_mapped'] = column_mapped_query
        ex['value_mapped'] = value_mapped_query
        ex['norm_query'] = query.strip().split()
        ex['db_id'] = db_id
        return ex
示例#7
0
    def sample_one(self, db_root, stats, allowed_db_ids, proc_cols, supports):
        """Sample a (previous, current) template pair and instantiate it on a database.

        ``stats`` is a pair ``(p_prevs, p_curr_given_prevs)``. The first is a
        distribution over previous templates; the second maps each previous
        template to a distribution over current templates. Both are sampled
        with ``np.random.choice``.

        A database is then drawn uniformly from those in ``allowed_db_ids``
        whose ``supports`` entry covers every ``*_col_*`` placeholder of both
        templates. Both templates are filled with ``self.fill_template`` using
        one shared random placeholder -> column mapping.

        Returns:
            dict with both templates and their probabilities, support
            statistics, the postprocessed current query (``'recov'``), aligned
            values, the column/value-mapped token lists of both queries and
            the db_id.

        Raises:
            NoSupportedDBException: if no allowed database supports the templates.
        """
        p_prevs, p_curr_given_prevs = stats

        # sample templates: previous first, then current conditioned on it
        templates = sorted(p_prevs)
        template_probs = [p_prevs[k] for k in templates]
        prev_template = np.random.choice(templates, p=template_probs)

        templates = sorted(p_curr_given_prevs[prev_template])
        template_probs = [
            p_curr_given_prevs[prev_template][k] for k in templates
        ]
        curr_template = np.random.choice(templates, p=template_probs)

        ex = dict(
            prev_template=prev_template,
            p_prev_template=p_prevs[prev_template],
            curr_template=curr_template,
            p_curr_template=p_curr_given_prevs[prev_template][curr_template])

        # gather allowed databases supporting every placeholder of both templates
        vcols = {
            t
            for t in prev_template.split() + curr_template.split()
            if '_col_' in t
        }
        supported_dbs = [
            db_id for db_id, supported_cols in supports.items()
            if db_id in allowed_db_ids and not vcols - supported_cols
        ]
        if not supported_dbs:
            raise NoSupportedDBException(curr_template)
        supported_dbs.sort()
        ex['num_supported_db'] = len(supported_dbs)
        ex['p_supported_db'] = 1 / len(supported_dbs)

        # sample a database at random
        db_id = np.random.choice(supported_dbs)
        db_path = os.path.join(db_root, db_id, db_id + '.sqlite')
        cols = proc_cols[db_id]

        # make a random mapping of '<type>_col_<i>' placeholders to column keys
        cols_shuffled = cols[:]
        np.random.shuffle(cols_shuffled)
        col_map = {c['key']: c for c in cols}
        col_map['*'] = '*'
        mapping = {}
        # Number placeholders per exact column type. A prefix scan over the
        # mapping keys would be quadratic and would miscount when one type name
        # is a prefix of another.
        type_counts = {}
        for c in cols_shuffled:
            idx = type_counts.get(c['type'], 0)
            type_counts[c['type']] = idx + 1
            mapping['{}_col_{}'.format(c['type'], idx)] = c['key']

        # fill both templates using the same column mapping
        prev_column_mapped_query, prev_value_mapped_query = self.fill_template(
            prev_template, mapping, col_map, db_path)
        curr_column_mapped_query, curr_value_mapped_query = self.fill_template(
            curr_template, mapping, col_map, db_path)

        ex['recov'] = editsql_postprocess.postprocess_one(
            ' '.join(curr_column_mapped_query),
            self.conv.database_schemas[db_id])
        # Raw string: the pattern is unchanged, but the old literal relied on
        # invalid escape sequences. NOTE(review): the match includes both
        # surrounding whitespace characters and the '*', so ' t1.* ' is deleted
        # entirely and its neighbours are joined; confirm this is intended.
        ex['recov'] = re.sub(r'\s([0-9a-zA-Z_]+\.)\*\s', '', ex['recov'])
        ex['values'] = preprocess_nl2sql.SQLDataset.align_values(
            curr_column_mapped_query, curr_value_mapped_query)

        ex['prev_column_mapped'] = prev_column_mapped_query
        ex['prev_value_mapped'] = prev_value_mapped_query
        ex['curr_column_mapped'] = curr_column_mapped_query
        ex['curr_value_mapped'] = curr_value_mapped_query
        ex['norm_query'] = curr_column_mapped_query
        ex['db_id'] = db_id
        return ex
示例#8
0
    def sample_one(cls, db_root, stats, allowed_db_ids, schema_tokens,
                   column_names, database_schemas, proc_cols, supports):
        """Sample a database uniformly, then a template that database supports.

        The sampling proceeds in four steps:

        1. A database is drawn uniformly from ``allowed_db_ids``.
        2. Templates in ``stats`` whose ``*_col_*`` placeholders are all in
           ``supports[db_id]`` are kept, and their probabilities are
           renormalised over that subset before one is drawn.
        3. Placeholders ``<type>_col_<i>`` are bound to randomly shuffled
           columns.
        4. ``value`` slots are filled with ``cls.sample_value``.

        Args:
            db_root: directory containing ``<db_id>/<db_id>.sqlite`` files.
            stats: mapping template string -> (unnormalised) probability.
            allowed_db_ids: candidate database ids.
            schema_tokens, column_names: not used by this method.
            database_schemas: db_id -> schema passed to
                ``editsql_postprocess.postprocess_one``.
            proc_cols: db_id -> list of column dicts (``'key'``/``'type'`` read here).
            supports: db_id -> set of placeholder names the database can fill.

        Returns:
            dict with the template and its ``stats`` probability (not
            renormalised), the postprocessed SQL (``'recov'``), aligned values,
            the mapped token lists, the normalized query tokens and the db_id.

        Raises:
            NoSupportedDBException: if the sampled database supports no template.
        """
        # sample a database
        all_db_ids = sorted(allowed_db_ids)
        db_id = np.random.choice(all_db_ids)
        db_path = os.path.join(db_root, db_id, db_id + '.sqlite')
        cols = proc_cols[db_id]

        # keep templates whose column placeholders this database supports
        supported_cols = supports[db_id]
        supported_templates = sorted(
            template for template in stats
            if not {t for t in template.split() if '_col_' in t} - supported_cols)
        if not supported_templates:
            raise NoSupportedDBException()

        # renormalise the template probabilities over the supported subset
        supported_templates_p = [stats[k] for k in supported_templates]
        partition = sum(supported_templates_p)
        supported_templates_p = [p / partition for p in supported_templates_p]
        template = np.random.choice(supported_templates, p=supported_templates_p)

        ex = dict(template=template, p_template=stats[template])

        # make a random mapping of '<type>_col_<i>' placeholders to column keys
        cols_shuffled = cols[:]
        np.random.shuffle(cols_shuffled)
        col_map = {c['key']: c for c in cols}
        col_map['*'] = '*'
        mapping = {}
        # Number placeholders per exact column type. A prefix scan over the
        # mapping keys would be quadratic and would miscount when one type name
        # is a prefix of another.
        type_counts = {}
        for c in cols_shuffled:
            idx = type_counts.get(c['type'], 0)
            type_counts[c['type']] = idx + 1
            mapping['{}_col_{}'.format(c['type'], idx)] = c['key']

        # Insert the mapping into the query. Remember the most recent column
        # and, if it directly follows '(', the token before the '(' (presumably
        # an aggregate); both guide sample_value.
        column_mapped_query = [mapping.get(t, t).lower() for t in template.split()]
        value_mapped_query = []
        last_col = None
        last_op = None
        query = ''
        for i, t in enumerate(column_mapped_query):
            if t in col_map:
                last_col = col_map[t]
                if i - 2 >= 0 and column_mapped_query[i - 1] == '(':
                    last_op = column_mapped_query[i - 2]
                else:
                    last_op = None
            if t == 'value':
                val, val_toks = cls.sample_value(last_col, last_op, db_path)
                query += ' ' + val
                value_mapped_query.extend(val_toks)
            elif t == 'limit_value':
                query += ' limit 1'
                value_mapped_query.extend(['limit', '1'])
            else:
                query += ' ' + t
                value_mapped_query.append(t)

        ex['recov'] = editsql_postprocess.postprocess_one(query, database_schemas[db_id])
        # Raw string: the pattern is unchanged, but the old literal relied on
        # invalid escape sequences. NOTE(review): the match includes both
        # surrounding whitespace characters and the '*', so ' t1.* ' is deleted
        # entirely and its neighbours are joined; confirm this is intended.
        ex['recov'] = re.sub(r'\s([0-9a-zA-Z_]+\.)\*\s', '', ex['recov'])
        ex['values'] = SQLDataset.align_values(column_mapped_query, value_mapped_query)

        ex['column_mapped'] = column_mapped_query
        ex['value_mapped'] = value_mapped_query
        ex['norm_query'] = query.strip().split()
        ex['db_id'] = db_id
        return ex
示例#9
0
文件: converter.py 项目: vzhong/gazp
 def recover(self, converted, db_id):
     """Turn a converted query string back into SQL for database ``db_id``.

     Looks up the schema for ``db_id`` in ``self.database_schemas`` and returns
     ``editsql_postprocess.postprocess_one(converted, schema)``.
     """
     schema = self.database_schemas[db_id]
     return editsql_postprocess.postprocess_one(converted, schema)