def set_question_literal_info(question_toks, column_names, table_names,
                              question_literal, q_literal_type):
    """TODO: Docstring for set_question_literal_info.

    Args:
        question_toks (TYPE): NULL
        column_names (TYPE): NULL
        table_names (TYPE): NULL
        question_literal (TYPE): [out]
        q_literal_type (TYPE): [out]

    Returns: TODO

    Raises: NULL
    """
    idx = 0
    while idx < len(question_toks):
        # check for full column name
        end_idx, header = utils.fully_part_header(question_toks, idx,
                                                  column_names)
        if header:
            question_literal.append(question_toks[idx:end_idx])
            q_literal_type.append(["col"])
            idx = end_idx
            continue

        # check for table
        end_idx, tname = utils.group_header(question_toks, idx, table_names)
        if tname:
            question_literal.append(question_toks[idx:end_idx])
            q_literal_type.append(["table"])
            idx = end_idx
            continue

        # check for column
        end_idx, header = utils.group_header(question_toks, idx, column_names)
        if header:
            question_literal.append(question_toks[idx:end_idx])
            q_literal_type.append(["col"])
            idx = end_idx
            continue

        if utils.group_digital(question_toks, idx):
            question_literal.append(question_toks[idx:idx + 1])
            q_literal_type.append(["value"])
            idx += 1
            continue

        question_literal.append([question_toks[idx]])
        q_literal_type.append(['NONE'])
        idx += 1
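
# A minimal usage sketch (hypothetical tokens and schema; the utils matchers
# are assumed to behave as described above). question_literal and
# q_literal_type are out-parameters filled in parallel:
#
#     question_toks = ['show', 'student', 'name', 'older', 'than', '20']
#     question_literal, q_literal_type = [], []
#     set_question_literal_info(question_toks, ['student name', 'age'],
#                               ['student'], question_literal, q_literal_type)
#     # question_literal ~ [['show'], ['student', 'name'], ['older'],
#     #                     ['than'], ['20']]
#     # q_literal_type   ~ [['NONE'], ['col'], ['NONE'], ['NONE'], ['value']]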

# Example #2

import os
import pickle

import nltk
from nltk.stem import WordNetLemmatizer

# Assumed setup: the helpers used below (symbol_filter, re_lemma,
# fully_part_header, group_header, partial_header, group_symbol, group_values,
# group_digital, num2year, and the AGG keyword list) come from the project's
# preprocessing utilities and are imported at module level.
wordnet_lemmatizer = WordNetLemmatizer()

def process_datas(datas, args):
    """

    :param datas:
    :param args:
    :return:
    """
    with open(os.path.join(args.conceptNet, 'english_RelatedTo.pkl'), 'rb') as f:
        english_RelatedTo = pickle.load(f)

    with open(os.path.join(args.conceptNet, 'english_IsA.pkl'), 'rb') as f:
        english_IsA = pickle.load(f)

    # preserve the original question_toks before they are overwritten below
    for d in datas:
        if 'origin_question_toks' not in d:
            d['origin_question_toks'] = d['question_toks']

    for entry in datas:
        entry['question_toks'] = symbol_filter(entry['question_toks'])
        origin_question_toks = symbol_filter([x for x in entry['origin_question_toks'] if x.lower() != 'the'])
        question_toks = [wordnet_lemmatizer.lemmatize(x.lower()) for x in entry['question_toks'] if x.lower() != 'the']

        entry['question_toks'] = question_toks

        table_names = []
        table_names_pattern = []

        for y in entry['table_names']:
            x = [wordnet_lemmatizer.lemmatize(x.lower()) for x in y.split(' ')]
            table_names.append(" ".join(x))
            x = [re_lemma(x.lower()) for x in y.split(' ')]
            table_names_pattern.append(" ".join(x))

        header_toks = []
        header_toks_list = []

        header_toks_pattern = []
        header_toks_list_pattern = []

        for y in entry['col_set']:
            x = [wordnet_lemmatizer.lemmatize(x.lower()) for x in y.split(' ')]
            header_toks.append(" ".join(x))
            header_toks_list.append(x)

            x = [re_lemma(x.lower()) for x in y.split(' ')]
            header_toks_pattern.append(" ".join(x))
            header_toks_list_pattern.append(x)

        num_toks = len(question_toks)
        idx = 0
        tok_concol = []
        type_concol = []
        nltk_result = nltk.pos_tag(question_toks)
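        # The POS tags drive two rules below: comparatives (RBR/JJR) are
        # tagged 'MORE', superlatives (RBS/JJS) are tagged 'MOST'.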

        while idx < num_toks:

            # check for full column name
            end_idx, header = fully_part_header(question_toks, idx, num_toks, header_toks)
            if header:
                tok_concol.append(question_toks[idx: end_idx])
                type_concol.append(["col"])
                idx = end_idx
                continue

            # check for table
            end_idx, tname = group_header(question_toks, idx, num_toks, table_names)
            if tname:
                tok_concol.append(question_toks[idx: end_idx])
                type_concol.append(["table"])
                idx = end_idx
                continue

            # check for column
            end_idx, header = group_header(question_toks, idx, num_toks, header_toks)
            if header:
                tok_concol.append(question_toks[idx: end_idx])
                type_concol.append(["col"])
                idx = end_idx
                continue

            # check for partial column
            end_idx, tname = partial_header(question_toks, idx, header_toks_list)
            if tname:
                tok_concol.append(tname)
                type_concol.append(["col"])
                idx = end_idx
                continue

            # check for aggregation
            end_idx, agg = group_header(question_toks, idx, num_toks, AGG)
            if agg:
                tok_concol.append(question_toks[idx: end_idx])
                type_concol.append(["agg"])
                idx = end_idx
                continue

            if nltk_result[idx][1] == 'RBR' or nltk_result[idx][1] == 'JJR':
                tok_concol.append([question_toks[idx]])
                type_concol.append(['MORE'])
                idx += 1
                continue

            if nltk_result[idx][1] == 'RBS' or nltk_result[idx][1] == 'JJS':
                tok_concol.append([question_toks[idx]])
                type_concol.append(['MOST'])
                idx += 1
                continue

            # string match for Time Format
            if num2year(question_toks[idx]):
                question_toks[idx] = 'year'
                end_idx, header = group_header(question_toks, idx, num_toks, header_toks)
                if header:
                    tok_concol.append(question_toks[idx: end_idx])
                    type_concol.append(["col"])
                    idx = end_idx
                    continue

            def get_concept_result(toks, graph):
                """Return the first schema column that the ConceptNet graph
                links to any sub-span of toks (longer spans tried first)."""
                for begin_id in range(0, len(toks)):
                    for r_ind in reversed(range(begin_id + 1, len(toks) + 1)):
                        tmp_query = "_".join(toks[begin_id:r_ind])
                        if tmp_query in graph:
                            mi = graph[tmp_query]
                            for col in entry['col_set']:
                                if col in mi:
                                    return col
                return None

            end_idx, symbol = group_symbol(question_toks, idx, num_toks)
            if symbol:
                tmp_toks = list(question_toks[idx:end_idx])
                assert len(tmp_toks) > 0, "empty symbol span: {} {}".format(
                    symbol, question_toks)
                pro_result = get_concept_result(tmp_toks, english_IsA)
                if pro_result is None:
                    pro_result = get_concept_result(tmp_toks, english_RelatedTo)
                if pro_result is None:
                    pro_result = "NONE"
                for tmp in tmp_toks:
                    tok_concol.append([tmp])
                    type_concol.append([pro_result])
                    pro_result = "NONE"
                idx = end_idx
                continue

            end_idx, values = group_values(origin_question_toks, idx, num_toks)
            if values and (len(values) > 1 or question_toks[idx - 1] not in ['?', '.']):
                tmp_toks = [wordnet_lemmatizer.lemmatize(x)
                            for x in question_toks[idx:end_idx] if x.isalnum()]
                assert len(tmp_toks) > 0, "empty value span: {} {} {} {} {}".format(
                    question_toks[idx:end_idx], values, question_toks, idx, end_idx)
                pro_result = get_concept_result(tmp_toks, english_IsA)
                if pro_result is None:
                    pro_result = get_concept_result(tmp_toks, english_RelatedTo)
                if pro_result is None:
                    pro_result = "NONE"
                for tmp in tmp_toks:
                    tok_concol.append([tmp])
                    type_concol.append([pro_result])
                    pro_result = "NONE"
                idx = end_idx
                continue

            # check for a plain number token
            if group_digital(question_toks, idx):
                tok_concol.append(question_toks[idx: idx + 1])
                type_concol.append(["value"])
                idx += 1
                continue
            # repair the lemmatizer artifact ('has' -> 'ha')
            if question_toks[idx] == 'ha':
                question_toks[idx] = 'have'

            tok_concol.append([question_toks[idx]])
            type_concol.append(['NONE'])
            idx += 1
            continue

        entry['question_arg'] = tok_concol
        entry['question_arg_type'] = type_concol
        entry['nltk_pos'] = nltk_result

    return datas
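
# Usage sketch (hypothetical paths; args.conceptNet must point to a directory
# holding english_RelatedTo.pkl and english_IsA.pkl, each mapping an
# underscore-joined phrase to a collection of related ConceptNet terms):
#
#     with open('train.json') as f:
#         datas = json.load(f)
#     datas = process_datas(datas, args)
#     datas[0]['question_arg'], datas[0]['question_arg_type']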

# Assumed to be defined elsewhere in the repository: the utils helper module,
# set_p_match_fea, and the g_max_values candidate limit.
from collections import defaultdict

import numpy as np


def process_cells(dct_cells, data, question_tokens, is_train):
    """Collect and score candidate cell values for one example.

    dct_cells: per-database cells, {db_id: cells}
    data: one dataset example
    question_tokens: tokenized question
    is_train: True when processing training data (gold SQL values available)
    :return: (values, value_features, val_col_tab); val_col_tab is
        currently always empty
    """
    cell_columns = []
    dct_value_score = defaultdict(int)

    if is_train:
        gold_values = utils.extract_gold_value(data['sql'])
        for idx, val in enumerate(gold_values):
            if utils.is_simple_float(val) and utils.is_int_float(float(val)):
                gold_values[idx] = str(int(float(val)))
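        # gold SQL values get a dominating score so they always survive the
        # top-g_max_values cut below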
        dct_value_score.update((v, 100) for v in set(gold_values))

    curr_cells = dct_cells[data['db_id']]
    q_tok_set = set(question_tokens)
    for table_name, table in curr_cells['tables'].items():
        table_score = utils.match_score(table_name, q_tok_set)
        col_dtypes = table['type']
        cols = table['header']
        col_scores = [
            utils.match_score(c, q_tok_set, base=table_score) for c in cols
        ]
        for row in table['cell']:
            for one_cell, col_score, col_dtype in zip(row, col_scores,
                                                      col_dtypes):
                if col_dtype in ('number', 'time') or one_cell == "":
                    continue
                dct_value_score[one_cell] += utils.match_score(one_cell,
                                                               q_tok_set,
                                                               base=col_score)

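    # smoothing: every unigram and each adjacent-token bigram (concatenated
    # without a space) gets a small score, so plain question words can also
    # surface as candidate values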
    for idx, word in enumerate(question_tokens[:-1]):
        dct_value_score[word] += 0.1
        dct_value_score[word + question_tokens[idx + 1]] += 0.1
    dct_value_score[question_tokens[-1]] += 0.1
    if len(dct_value_score) <= g_max_values:
        values = list(sorted(dct_value_score.keys()))
    else:
        values = [
            x[0] for x in sorted(dct_value_score.items(),
                                 key=lambda x: x[1],
                                 reverse=True)[:g_max_values]
        ]

    value_features = np.zeros((len(values), 2), dtype=np.int64)  # EM, PM
    set_p_match_fea(data['question_tokens'], values, 1, value_features)

    idx = 0
    values_set = set(values)
    question_toks = data['question_tokens']
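    # mark the exact-match (EM) feature for every candidate value that occurs
    # as a contiguous span of the question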
    while idx < len(question_toks):
        end_idx, cellvalue = utils.group_header(question_toks, idx, values_set)
        if cellvalue:
            idx = end_idx
            for i, v in enumerate(values):
                if cellvalue == v:
                    value_features[i][0] = 5
        else:
            idx += 1

    return values, value_features, []
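
# Usage sketch (hypothetical data; dct_cells follows the
# {db_id: {'tables': {name: {'header', 'type', 'cell'}}}} layout read above):
#
#     values, value_features, _ = process_cells(
#         dct_cells, example, example['question_tokens'], is_train=True)
#     # values: up to g_max_values candidate strings
#     # value_features: (len(values), 2) int array of EM/PM match features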