Example #1
0
# Stdlib / third-party dependencies; `split`, `prefix_match`, and `utils`
# are project-local helpers assumed to be in scope.
import difflib

from fuzzywuzzy import fuzz
from six import string_types


def get_matched_entries(s, field_values, m_theta=0.85, s_theta=0.85):
    if not field_values:
        return None

    if isinstance(s, str):
        n_grams = split(s)
    else:
        n_grams = s

    matched = dict()
    for field_value in field_values:
        if not isinstance(field_value, string_types):
            continue
        fv_tokens = split(field_value)
        sm = difflib.SequenceMatcher(None, n_grams, fv_tokens)
        match = sm.find_longest_match(0, len(n_grams), 0, len(fv_tokens))
        if match.size > 0:
            source_match = get_effective_match_source(n_grams, match.a,
                                                      match.a + match.size)
            if source_match and source_match.size > 1:
                match_str = field_value[match.b:match.b + match.size]
                source_match_str = s[source_match.start:source_match.start +
                                     source_match.size]
                c_match_str = match_str.lower().strip()
                c_source_match_str = source_match_str.lower().strip()
                c_field_value = field_value.lower().strip()
                if c_match_str and not utils.is_number(c_match_str) \
                        and not utils.is_common_db_term(c_match_str):
                    if utils.is_stopword(c_match_str) or utils.is_stopword(c_source_match_str) or \
                            utils.is_stopword(c_field_value):
                        continue
                    if c_source_match_str.endswith(c_match_str + '\'s'):
                        match_score = 1.0
                    else:
                        if prefix_match(c_field_value, c_source_match_str):
                            match_score = fuzz.ratio(c_field_value,
                                                     c_source_match_str) / 100
                        else:
                            match_score = 0
                    if (utils.is_commonword(c_match_str)
                            or utils.is_commonword(c_source_match_str)
                            or utils.is_commonword(c_field_value)
                        ) and match_score < 1:
                        continue
                    s_match_score = match_score
                    if match_score >= m_theta and s_match_score >= s_theta:
                        if field_value.isupper() and match_score * s_match_score < 1:
                            continue
                        matched[match_str] = (field_value, source_match_str,
                                              match_score, s_match_score,
                                              match.size)

    if not matched:
        return None
    else:
        return sorted(matched.items(),
                      key=lambda x: (1e16 * x[1][2] + 1e8 * x[1][3] + x[1][4]),
                      reverse=True)
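For reference, a minimal self-contained sketch of the matching idea above: the longest contiguous token overlap found by difflib.SequenceMatcher, scored against a threshold. The question and picklist below are hypothetical inputs, and the 0.85 cutoff mirrors m_theta.

import difflib

def longest_match_ratio(query_tokens, value_tokens):
    # Longest contiguous token overlap, normalized by the value length.
    sm = difflib.SequenceMatcher(None, query_tokens, value_tokens)
    m = sm.find_longest_match(0, len(query_tokens), 0, len(value_tokens))
    return m.size / len(value_tokens) if value_tokens else 0.0

question = 'how many students enrolled in fall 2019'.split()
picklist = ['fall 2019', 'spring 2020']
for value in picklist:
    score = longest_match_ratio(question, value.split())
    if score >= 0.85:  # analogous to m_theta above
        print(value, score)  # fall 2019 1.0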
Example #2
0
def is_numeric_field(self, s):
    assert isinstance(s, string_types)
    if is_number(s):
        return False
    assert re.fullmatch(utils.field_pattern, s)
    field_id = self.schema.get_field_id(s)
    field_node = self.schema.get_field(field_id)
    return field_node.is_numeric
def MaxAmount(amount):
    # Pad the amount by 0.01% to obtain the maximum sendable value.
    if isinstance(amount, str) and is_number(amount):
        _amount = safe_int(Number(amount) * 1.0001)
        return str(_amount)
    if isinstance(amount, dict) and utils.isValidAmount(amount):
        _value = Number(amount['value']) * 1.0001
        amount['value'] = str(_value)
        return amount
    raise ValueError('invalid amount to max')
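A tiny standalone illustration of the +0.01% padding MaxAmount applies; safe_int and Number are project helpers, so plain int/float coercion stands in here (an assumption).

def max_amount(amount: str) -> str:
    # Pad by 0.01% and truncate, mirroring safe_int(Number(amount) * 1.0001).
    return str(int(float(amount) * 1.0001))

print(max_amount('1000000'))  # '1000100'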
Example #4
0
def is_field(self, s):
    if not isinstance(s, string_types):
        return False
    if is_number(s):
        return False
    if re.fullmatch(utils.field_pattern, s):
        table_name, field_name = s.split('.')
        if re.fullmatch(utils.alias_pattern, table_name):
            table_name = self.get_table_name_by_alias(table_name)
        return self.schema.is_table_name(table_name) and self.schema.is_field_name(field_name)
    else:
        return self.schema.is_field_name(s)
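The real utils.field_pattern and utils.alias_pattern are defined elsewhere in the project; the hypothetical stand-ins below only illustrate the table.field / alias resolution flow of is_field.

import re

# Hypothetical stand-ins for utils.field_pattern and utils.alias_pattern.
FIELD_PATTERN = re.compile(r'\w+\.\w+')
ALIAS_PATTERN = re.compile(r'[tT]\d+')

def qualified_parts(s):
    # Split 't1.name' or 'students.name' into (table, field), else None.
    if FIELD_PATTERN.fullmatch(s):
        table_name, field_name = s.split('.')
        if ALIAS_PATTERN.fullmatch(table_name):
            table_name = '<alias:%s>' % table_name  # alias lookup goes here
        return table_name, field_name
    return None

print(qualified_parts('t1.name'))        # ('<alias:t1>', 'name')
print(qualified_parts('students.name'))  # ('students', 'name')
print(qualified_parts('name'))           # None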
def Number(num):
    if not is_number(num):
        return float('nan')
    if isinstance(num, bool):
        # bool is a subclass of int; map True/False to 1/0 explicitly.
        return 1 if num else 0
    if isinstance(num, (int, float)):
        return num
    # num is a numeric string at this point.
    if '.' in num:
        return float(num)
    else:
        return int(num)
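A self-contained version of the same coercion ladder, with a stand-in _is_number (an assumption; the project's is_number may differ) plus a guard for exponent notation that a '.'-only check would miss.

def _is_number(x):
    # Stand-in for the project-local is_number helper (assumption).
    if isinstance(x, (bool, int, float)):
        return True
    try:
        float(x)
        return True
    except (TypeError, ValueError):
        return False

def number(num):
    if not _is_number(num):
        return float('nan')
    if isinstance(num, bool):
        return 1 if num else 0
    if isinstance(num, (int, float)):
        return num
    # Numeric string: int only when it is all digits, else float.
    return int(num) if num.lstrip('+-').isdigit() else float(num)

assert number(True) == 1
assert number('7') == 7
assert number('3.5') == 3.5
assert number('1e5') == 100000.0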
Example #6
0
def extract_value_spans(program_tokens, program_token_types, tu):
    values = []
    value = []

    def flush_value():
        # Join the accumulated value tokens and undo tokenizer spacing
        # around '.', '@' and '-'.
        value_str = tu.tokenizer.convert_tokens_to_string(value)
        value_str = value_str.replace(' . ', '.')
        value_str = value_str.replace(' @ ', '@')
        value_str = value_str.replace(' - ', '-')
        if not utils.is_number(value_str):
            values.append(value_str)

    for t, t_type in zip(program_tokens, program_token_types):
        if t_type == sql_tokenizer.VALUE:
            value.append(t)
        elif value:
            flush_value()
            value = []
    if value:
        # Flush a value span that extends to the end of the program.
        flush_value()
    return values
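A minimal sketch of the same span-grouping loop using plain whitespace joining; VALUE here is a stand-in for the sql_tokenizer.VALUE tag.

VALUE = 'VALUE'

def group_value_spans(tokens, types):
    spans, current = [], []
    for tok, tok_type in zip(tokens, types):
        if tok_type == VALUE:
            current.append(tok)
        elif current:
            spans.append(' '.join(current))
            current = []
    if current:  # a value span may run to the end of the sequence
        spans.append(' '.join(current))
    return spans

tokens = ['name', '=', 'john', 'smith']
types = ['FIELD', 'OP', VALUE, VALUE]
print(group_value_spans(tokens, types))  # ['john smith']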
def dispatch(self, json, is_table=False, should_quote=should_quote):
    if isinstance(json, list):
        return self.delimited_list(json)

    if isinstance(json, dict):
        if len(json) == 0:
            return [], []
        elif 'value' in json:
            return self.value(json)
        elif 'from' in json:
            # Nested query 'from'
            return add_parentheses(self.tokenize(json))
        elif 'query' in json:
            # Nested query 'query'
            nested_query_tokens = self.tokenize(json['query'])
            if 'name' in json:
                return connect_by_keywords(
                    add_parentheses(nested_query_tokens), self.dispatch(json['name'], is_table=True), ['AS'])
            else:
                return add_parentheses(nested_query_tokens)
        elif 'union' in json:
            # Nested query 'union'
            return add_parentheses(self.union(json['union']))
        elif 'intersect' in json:
            return add_parentheses(self.intersect(json['intersect']))
        elif 'except' in json:
            return add_parentheses(self.except_(json['except']))
        else:
            return self.op(json)

    if not isinstance(json, string_types):
        json = text(json)
    if is_table and json.lower() == 't0':
        return self.value_tokenize(json), [RESERVED_TOKEN, RESERVED_TOKEN]
    if self.keep_singleton_fields and (is_table or self.is_field(json) or json == '*'):
        if is_table:
            return [json], [TABLE]
        else:
            return [json], [FIELD]
    if self.atomic_value:
        self.constants.append(escape(json, self.value_tokenize, self.ansi_quotes, never))
        if is_number(json):
            return [self.num_token], [VALUE]
        else:
            return [self.str_token], [VALUE]
    else:
        return escape(json, self.value_tokenize, self.ansi_quotes, should_quote)
def func(self, op, json):
    # `op` is the operator name this handler was dispatched for.
    if op in ['<>', '>', '<', '>=', '<=', '=', '!='] and \
            isinstance(json[0], string_types) and \
            (isinstance(json[1], string_types) or (isinstance(json[1], dict) and 'literal' in json[1])):
        assert isinstance(json, list) and len(json) == 2
        v1, v2 = json
        if isinstance(v2, dict):
            v2 = v2['literal']
        if is_number(v2):
            return
        if v1 != v2:
            if self.is_field(v1) and not self.is_field(v2):
                v1_id = self.schema.get_field_id(v1)
                v1 = self.schema.get_field_signature(v1_id)
                self.values.append((v1, v2))
    else:
        for v in json:
            self.dispatch(v)
def _literal(self, json):
    if isinstance(json, list):
        return add_parentheses(
            (functools.reduce(lambda x, y: x + y, [self._literal(v)[0] for v in json]),
             functools.reduce(lambda x, y: x + y, [self._literal(v)[1] for v in json])))
    elif isinstance(json, string_types):
        if self.atomic_value:
            self.constants.append(escape(json, self.value_tokenize, self.ansi_quotes, never))
            if is_number(json):
                return [self.num_token], [VALUE]
            else:
                return [self.str_token], [VALUE]
        else:
            return escape(json, self.value_tokenize, self.ansi_quotes, always)
    else:
        tokens = self.value_tokenize(text(json))
        token_types = [VALUE for _ in tokens]
        return tokens, token_types
def preprocess_example(split,
                       example,
                       args,
                       parsed_programs,
                       text_tokenize,
                       program_tokenize,
                       post_process,
                       table_utils,
                       schema_graph,
                       vocabs,
                       verbose=False):
    tu = table_utils
    text_vocab = vocabs['text']
    program_vocab = vocabs['program']

    def get_memory_values(features, raw_text, args):
        if (args.pretrained_transformer.startswith('bert-')
                and args.pretrained_transformer.endswith('-uncased')):
            return utils.restore_feature_case(features, raw_text, tu)
        else:
            # No case restoration performed; token offsets are unavailable.
            return features, None, None

    def get_text_schema_adjacency_matrix(text_features, s_M):
        schema_size = s_M.shape[0]
        text_size = len(text_features)
        full_size = schema_size + text_size
        M = ssp.lil_matrix((full_size, full_size), dtype=int)
        M[-schema_size:, -schema_size:] = s_M
        return M

    # sanity check
    ############################
    query_oov = False
    denormalized = False
    schema_truncated = False
    token_restored = True
    ############################

    # Text feature extraction and set program ground truth list
    if isinstance(example, Text2SQLExample):
        # Token offsets are only recoverable in the transformer branch.
        token_starts, token_ends = None, None
        if args.pretrained_transformer:
            text_features = text_tokenize(example.text)
            text_tokens, token_starts, token_ends = get_memory_values(
                text_features, example.text, args)
            if not token_starts:
                token_restored = False
        else:
            text_tokens = text_tokenize(example.text, functional_tokens)
            text_features = [t.lower() for t in text_tokens]
        example.text_tokens = text_features
        example.text_ptr_values = text_tokens
        example.text_token_starts = token_starts
        example.text_token_ends = token_ends
        example.text_ids = vec.vectorize(text_features, text_vocab)
        example.text_ptr_input_ids = vec.vectorize(text_features, text_vocab)
        program_list = example.program_list
        example.values = [
            (schema_graph.get_field(cond[0]).signature, cond[2])
            for cond in example.program_ast_list_[0]['conds']
            if (isinstance(cond[2], str) and not is_number(cond[2]))
        ]
    else:
        text_tokens = example.example.text_ptr_values
        text_features = example.example.text_tokens
        program_list = example.example.program_list

    # Schema feature extraction
    if args.model_id in [BRIDGE]:
        question_encoding = example.text if args.use_picklist else None
        tables = sorted([schema_graph.get_table_id(t_name) for t_name in example.gt_table_names]) \
            if args.use_oracle_tables else None
        table_po, field_po = schema_graph.get_schema_perceived_order(tables)
        schema_features, matched_values = schema_graph.get_serialization(
            tu,
            flatten_features=True,
            table_po=table_po,
            field_po=field_po,
            use_typed_field_markers=args.use_typed_field_markers,
            use_graph_encoding=args.use_graph_encoding,
            question_encoding=question_encoding,
            top_k_matches=args.top_k_picklist_matches,
            num_values_per_field=args.num_values_per_field,
            no_anchor_text=args.no_anchor_text)
        example.matched_values = matched_values
        example.input_tokens, example.input_ptr_values, num_excluded_tables, num_excluded_fields = \
            get_table_aware_transformer_encoder_inputs(text_tokens, text_features, schema_features, table_utils)
        schema_truncated = (num_excluded_fields > 0)
        num_included_nodes = schema_graph.get_num_perceived_nodes(
            table_po) + 1 - num_excluded_tables - num_excluded_fields
        example.ptr_input_ids = vec.vectorize(example.input_tokens, text_vocab)
        if args.read_picklist:
            example.transformer_output_value_mask, value_features, value_tokens = \
                get_transformer_output_value_mask(example.input_tokens, matched_values, tu)
        example.primary_key_ids = schema_graph.get_primary_key_ids(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.foreign_key_ids = schema_graph.get_foreign_key_ids(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.field_type_ids = schema_graph.get_field_type_ids(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.table_masks = schema_graph.get_table_masks(num_included_nodes,
                                                           table_po=table_po,
                                                           field_po=field_po)
        example.field_table_pos = schema_graph.get_field_table_pos(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.schema_M = schema_graph.adj_matrix
        example.M = get_text_schema_adjacency_matrix(text_features,
                                                     example.schema_M)
    else:
        num_included_nodes = schema_graph.num_nodes

    # Value copy feature extraction
    if args.read_picklist:
        constant_memory_features = text_features + value_features
        constant_memory = text_tokens + value_tokens
        example.text_ptr_values = constant_memory
    else:
        constant_memory_features = text_features
    constant_ptr_value_ids, constant_unique_input_ids = vec.vectorize_ptr_in(
        constant_memory_features, program_vocab)
    if isinstance(example, Text2SQLExample):
        example.text_ptr_value_ids = constant_ptr_value_ids
    example.ptr_value_ids = constant_ptr_value_ids + [
        program_vocab.size + len(constant_memory_features) + x
        for x in range(num_included_nodes)
    ]

    if not args.leaderboard_submission:
        for j, program in enumerate(program_list):
            if isinstance(example, Text2SQLExample):
                # Model II. Bridge output
                program_singleton_field_tokens, program_singleton_field_token_types = \
                    tok.wikisql_struct_to_tokens(example.program_ast_, schema_graph, tu)
                program_singleton_field_tokens = [
                    START_TOKEN
                ] + program_singleton_field_tokens + [EOS_TOKEN]
                program_singleton_field_token_types = \
                    [RESERVED_TOKEN_TYPE] + program_singleton_field_token_types + [RESERVED_TOKEN_TYPE]
                example.program_singleton_field_tokens_list.append(
                    program_singleton_field_tokens)
                example.program_singleton_field_token_types_list.append(
                    program_singleton_field_token_types)
                program_singleton_field_input_ids = vec.vectorize_singleton(
                    program_singleton_field_tokens,
                    program_singleton_field_token_types, program_vocab)
                example.program_singleton_field_input_ids_list.append(
                    program_singleton_field_input_ids)
            else:
                # Model II. Bridge output
                example.program_singleton_field_input_ids_list.append(
                    example.example.program_singleton_field_input_ids_list[j])
                program_singleton_field_tokens = example.example.program_singleton_field_tokens_list[
                    j]
                program_singleton_field_token_types = example.example.program_singleton_field_token_types_list[
                    j]

            program_field_ptr_value_ids = vec.vectorize_field_ptr_out(
                program_singleton_field_tokens,
                program_singleton_field_token_types,
                program_vocab,
                constant_unique_input_ids,
                max_memory_size=len(constant_memory_features),
                schema=schema_graph,
                num_included_nodes=num_included_nodes)
            example.program_text_and_field_ptr_value_ids_list.append(
                program_field_ptr_value_ids)

            table_ids = [
                schema_graph.get_table_id(table_name)
                for table_name in example.gt_table_names_list[j]
            ]
            example.table_ids_list.append(table_ids)
            assert ([schema_graph.get_table(x).name
                     for x in table_ids] == example.gt_table_names)

            # sanity check
            ############################
            #   NL+Schema pointer output contains tokens that do not belong to any of the following categories
            if verbose:
                if program_vocab.unk_id in program_field_ptr_value_ids:
                    unk_indices = [
                        i for i, x in enumerate(program_field_ptr_value_ids)
                        if x == program_vocab.unk_id
                    ]
                    print('OOV II: {}'.format(' '.join([
                        program_singleton_field_tokens[i] for i in unk_indices
                    ])))
                    example.pretty_print(
                        schema=schema_graph,
                        de_vectorize_ptr=vec.de_vectorize_ptr,
                        de_vectorize_field_ptr=vec.de_vectorize_field_ptr,
                        rev_vocab=program_vocab,
                        post_process=post_process,
                        use_table_aware_te=(args.model_id in [BRIDGE]))
                    query_oov = True
            if program_vocab.unk_field_id in program_field_ptr_value_ids:
                example.pretty_print(
                    schema=schema_graph,
                    de_vectorize_ptr=vec.de_vectorize_ptr,
                    de_vectorize_field_ptr=vec.de_vectorize_field_ptr,
                    rev_vocab=program_vocab,
                    post_process=post_process,
                    use_table_aware_te=(args.model_id in [BRIDGE]))
            if program_vocab.unk_table_id in program_field_ptr_value_ids:
                example.pretty_print(
                    schema=schema_graph,
                    de_vectorize_ptr=vec.de_vectorize_ptr,
                    de_vectorize_field_ptr=vec.de_vectorize_field_ptr,
                    rev_vocab=program_vocab,
                    post_process=post_process,
                    use_table_aware_te=(args.model_id in [BRIDGE]))
            ############################

            # Store the ground truth queries after preprocessing to run a relaxed evaluation or
            # to evaluate with partial queries
            if split == 'dev':
                input_tokens = text_tokens
                if args.model_id in [BRIDGE]:
                    _p = vec.de_vectorize_field_ptr(
                        program_field_ptr_value_ids,
                        program_vocab,
                        input_tokens,
                        schema=schema_graph,
                        post_process=post_process)
                else:
                    _p = program
                example.gt_program_list.append(_p)

            # sanity check
            ############################
            # try:
            #     assert(equal_ignoring_trivial_diffs(_p, program.lower(), verbose=True))
            # except Exception:
            #     print('_p:\t\t{}'.format(_p))
            #     print('program:\t{}'.format(program))
            #     print()
            #     import pdb
            #     pdb.set_trace()
            ############################

        example.run_unit_tests()

    return query_oov, denormalized, schema_truncated, token_restored
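As a standalone sanity check on get_text_schema_adjacency_matrix above: the schema adjacency lands in the bottom-right block while the text rows and columns stay empty.

import numpy as np
import scipy.sparse as ssp

def text_schema_adjacency(text_size, s_M):
    schema_size = s_M.shape[0]
    full_size = text_size + schema_size
    M = ssp.lil_matrix((full_size, full_size), dtype=int)
    M[-schema_size:, -schema_size:] = s_M
    return M

s_M = ssp.lil_matrix(np.eye(2, dtype=int))
print(text_schema_adjacency(3, s_M).toarray())
# The 2x2 identity occupies the bottom-right block of the 5x5 matrix.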