Example 1
import codecs
import collections
import gzip
import re
import sys

import chainer
import numpy as xp  # replaced by cuda.cupy in main() when a GPU is used
from chainer import Variable, cuda, optimizers


def read_train(tree_file_path, align_file_path, vocab, max_size, vocab_size,
               tree_parser):
    """
    訓練用のファイルを読み込むための関数。
    """
    trees = []
    with codecs.open(tree_file_path, 'r',
                     'utf-8') as tree_file, gzip.open(
                         align_file_path, 'r') as align_file:
        for i, tree_line in enumerate(tree_file):
            if (i + 1) % 1000 == 0:
                print("%d lines have been read..." % (i + 1))
            tree_line = tree_line.strip()
            align_file.readline()  # skip line 3n+1 (not needed)
            # read the source-language words (line 3n+2)
            e_words = align_file.readline().strip().decode('utf-8').split()
            # read the target-language line (line 3n)
            f_line = align_file.readline().strip().decode('utf-8')
            f_words = re.split(r'\(\{|\}\)', f_line)[:-1]
            tmp_vocab_dict = collections.defaultdict(int)

            e_wordlist = []
            # source-side processing (repeated words get a numeric prefix to tell them apart)
            for e_word in e_words:
                tmp_vocab_dict[e_word] += 1
                e_wordlist.append(str(tmp_vocab_dict[e_word]) + '_' + e_word)

            f_word_dst = []
            align = []
            # for j in range(1, len(f_words) // 2):
            # NULL alignments are not considered, hence the commented variant starting at 1
            for j in range(len(f_words) // 2):
                f_word = f_words[2 * j]  # target-side word
                f_align = f_words[2 * j + 1].strip().split()  # source positions it aligns to
                f_word_dst.append(
                    (f_word, [e_wordlist[int(k) - 1] for k in f_align]))
                align.extend([int(k) for k in f_align])

            if tree_parser == 'enju':
                tree = EnjuXmlParser(tree_line)
            elif tree_parser == 's':
                tree = STreeParser(tree_line)
            else:
                print('invalid tree parser', file=sys.stderr)
                sys.exit(1)
            tree = tree.parse(tree.root)
            if tree['status'] == 'failed':
                continue
            trees.append(
                convert_tree_train(vocab, tree, e_wordlist, f_word_dst, 0,
                                   kendall_tau(align), vocab_size))
            if max_size and len(trees) >= max_size:
                break
    return trees, vocab
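# kendall_tau() is called throughout this example but defined elsewhere in
# the repository. A minimal sketch of one common definition (the fraction of
# concordant index pairs); the actual implementation may differ:
def kendall_tau(align):
    n = len(align)
    if n < 2:
        return 1.0  # degenerate case; this convention is an assumption
    pairs = n * (n - 1) / 2
    concordant = sum(1 for a in range(n - 1) for b in range(a + 1, n)
                     if align[a] <= align[b])
    return concordant / pairs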
def cky_with_score(leaves, model, vocab_dict):
    """
    Parse in a CKY-algorithm-like fashion.
    :param leaves:
    :return ground_truth_tree, score, loss, node_list:
    """
    last_node = None
    if leaves[-1].word in ['.', '?']:
        leaves, last_node = leaves[:-1], leaves[-1]
    trees = [
        [[model.leaf(xp.array([[vocab_dict.get(n.word, vocab_dict['<UNK>'])]]))]]
        for n in leaves
    ]  # turn each leaf into its vector representation
    len_leaves = len(leaves)
    for d in range(1, len_leaves):
        for i in range(len_leaves - d):
            nodes = []
            for j in range(len(trees[i])):
                # greedily keep the candidate with max Kendall's tau, then min num_swap, then max score
                # nodes += make_candidate(trees[i][j], trees[i+j+1][-j-1 + (len(trees[i]) - len(trees[i+j+1]))])
                node = None
                max_tau, min_swap = None, None
                max_score = Variable(xp.array([[0]], dtype=xp.float32))
                left_nodes, right_nodes = trees[i][j], trees[i + j + 1][
                    -j - 1 + (len(trees[i]) - len(trees[i + j + 1]))]
                for left_node in left_nodes:
                    for right_node in right_nodes:  # for each (left, right) pair
                        left_align = flatten_node(left_node)
                        right_align = flatten_node(right_node)
                        straight_tau = kendall_tau(left_align + right_align)
                        inverted_tau = kendall_tau(right_align + left_align)
                        if straight_tau < inverted_tau:
                            cand_node = Node(
                                left_node, right_node, 'Inverted',
                                left_node.swap + right_node.swap + 1)
                            cand_tau = inverted_tau
                        else:
                            cand_node = Node(left_node, right_node, 'Straight',
                                             left_node.swap + right_node.swap)
                            cand_tau = straight_tau
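# cky_with_score breaks off above in this listing. Judging from its docstring
# and the parallel loop in cky() below, each candidate pair is presumably
# compared on (Kendall's tau, number of swaps, model score). A self-contained
# sketch of such a selection rule; the helper name and the triple layout are
# assumptions, not the repository's code:
def pick_best_candidate(candidates):
    """candidates: iterable of (node, tau, score) triples for one span,
    with score as a plain float."""
    return max(candidates, key=lambda c: (c[1], -c[0].swap, c[2]))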
def make_candidate(left_nodes, right_nodes):
    nodes = []
    if isinstance(left_nodes, Nodes):
        left_nodes = left_nodes.nodes
    if isinstance(right_nodes, Nodes):
        right_nodes = right_nodes.nodes
    for left_node in left_nodes:
        for right_node in right_nodes:
            left_align = flatten_node(left_node)
            right_align = flatten_node(right_node)
            if kendall_tau(left_align +
                           right_align) < kendall_tau(right_align +
                                                      left_align):
                nodes.append(
                    Node(left_node, right_node, "Inverted",
                         left_node.swap + right_node.swap + 1))
            else:
                nodes.append(
                    Node(left_node, right_node, "Straight",
                         left_node.swap + right_node.swap))
    return nodes
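# Node, Nodes, and flatten_node are referenced above but defined elsewhere in
# the repository. A minimal sketch consistent with how they are used here;
# every name and field below is an assumption:
class Node(object):
    def __init__(self, left, right, direction, swap):
        self.left = left
        self.right = right
        self.direction = direction  # 'Straight' or 'Inverted'
        self.swap = swap            # cumulative number of inversions


class Nodes(object):
    def __init__(self, nodes):
        self.nodes = nodes          # surviving candidate Nodes for one span


def flatten_node(node):
    # Recover the reordered alignment sequence covered by `node`; an
    # 'Inverted' node emits its right child before its left child.
    if not isinstance(node, Node):
        return node.align           # assumed attribute on leaf nodes
    if node.direction == 'Inverted':
        return flatten_node(node.right) + flatten_node(node.left)
    return flatten_node(node.left) + flatten_node(node.right)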
def cky(leaves):
    """
    CKYアルゴリズムを模倣して木の構築を行う
    A -> BCのようなルールはなく全てのノードで候補があるため、各ノードで順位相関係数が高くなるように枝刈りを行う
    """
    # print_message("Construct Tree...")
    last_node = None
    # 最後が記号(.?)の時は途中で結合して欲しくないので退避
    if leaves[-1].word in [".", "?"]:
        leaves, last_node = leaves[:-1], leaves[-1]
    trees = [[[n]] for n in leaves]  # 葉ノードの構築
    len_leaves = len(leaves)
    for d in range(1, len_leaves):
        for i in range(len_leaves - d):
            nodes = []
            for j in range(len(trees[i])):  # exhaustive search over covered spans
                # build candidate nodes (the case split below is redundant; remove it later)
                if len(trees[i + j + 1]) == len(trees[i]):
                    nodes += make_candidate(trees[i][j],
                                            trees[i + j + 1][-j - 1])
                else:
                    nodes += make_candidate(
                        trees[i][j],
                        trees[i + j +
                              1][-j - 1 +
                                 (len(trees[i]) - len(trees[i + j + 1]))])
            # keep the candidates with the highest rank correlation coefficient
            taus = [kendall_tau(flatten_node(n)) for n in nodes]
            max_tau = max(taus)
            nodes = [n for n, t in zip(nodes, taus) if t == max_tau]
            # among those, keep the candidates with the fewest swaps
            min_swap = min(n.swap for n in nodes)
            nodes = [n for n in nodes if n.swap == min_swap]
            trees[i].append(Nodes(nodes))
    ts = trees[0][-1]
    if last_node:
        ts = Nodes([Node(n, last_node, "Straight", n.swap) for n in ts.nodes])
    return ts
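# Hypothetical usage of cky(): given the leaf nodes of one sentence, it
# returns a Nodes wrapper holding every candidate tree that survived the
# tau/swap pruning, e.g.:
#     ts = cky(leaves)
#     best_tree = ts.nodes[0]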
def main():
    global xp
    args = parse()
    print_message("Prepare training data...")

    model = RecursiveNet(args.vocab_size, args.embed_size, args.hidden_size,
                         args.label)

    if args.gpus >= 0:  # configure GPU use
        cuda.get_device_from_id(args.gpus).use()
        model.to_gpu(args.gpus)
        xp = cuda.cupy

    vocab_dict = make_vocabulary(args.alignmentfile, args.vocab_size)

    optm = optimizers.Adam()
    optm.setup(model)
    optm.add_hook(chainer.optimizer.WeightDecay(0.0001))
    optm.add_hook(chainer.optimizer.GradientClipping(5))

    # for c_t in data_prepare(args.alignmentfile):
    #     pprint.pprint(c_t)
    #     input()

    model = train(args.alignmentfile, args.epoch, model, optm, args.batchsize,
                  vocab_dict, args.num_trees)
    with gzip.open(args.alignmentfile, 'rb') as f:
        for _ in f:  # the loop variable consumes the header line of each 3-line block
            s = f.readline().strip().decode('utf-8').split(' ')
            target = f.readline().strip().decode('utf-8')
            target_word_align = re.split(r'\(\{|\}\)', target)[:-1]
            pred_idxes = predict(model, s, vocab_dict)
            print(pred_idxes)
            pred_idxes = flatten_tuple_tree(pred_idxes)
            s_idxes = []
            for a in target_word_align[1::2]:
                for _a in a.strip().split():
                    s_idxes.append(pred_idxes.index(int(_a) - 1))
            print("kendall's tau: %.4f" % kendall_tau(s_idxes))
            print(' '.join(s[pred_idx] for pred_idx in pred_idxes))
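# Entry-point guard, assuming this example is run as a script:
if __name__ == '__main__':
    main()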
Example 6
def convert_tree_dev(vocab, node, e_list, f_dst_list, j, tau):
    """
    :param e_list: list of source-language words
    :param f_dst_list: list mapping each target word to the source words it aligns to
    :param j: right edge of the span covered so far (???)
    :param tau: current Kendall's tau
    """
    if node['tag'] == 'sentence':
        children = []
        span = (j, j)
        for child in node['children']:
            span, child_node = convert_tree_dev(vocab, child, e_list,
                                                f_dst_list, span[1], tau)
            children.append(child_node)
        return {'tag': node['tag'], 'children': children}
    elif node['tag'] == 'cons':
        assert len(node['children']) == 1 or len(node['children']) == 2
        if len(node['children']) == 1:
            return convert_tree_dev(vocab, node['children'][0], e_list,
                                    f_dst_list, j, tau)
        else:
            swap = 0
            left_span, left_node = convert_tree_dev(vocab, node['children'][0],
                                                    e_list, f_dst_list, j, tau)
            right_span, right_node = convert_tree_dev(vocab,
                                                      node['children'][1],
                                                      e_list, f_dst_list,
                                                      max(left_span), tau)
            # build the reordering candidates
            span_min, span_max = min(left_span), max(right_span)
            # word order when the two spans are not swapped
            tmp_e_list = ([e_list[i] for i in range(span_min - 1)]
                          + [e_list[i - 1] for i in left_span]
                          + [e_list[i - 1] for i in right_span]
                          + [e_list[i] for i in range(span_max, len(e_list))])
            align = []
            for f_w_dst in f_dst_list:
                for e_dst in f_w_dst[1]:
                    if e_dst in tmp_e_list:
                        align.append(tmp_e_list.index(e_dst))
            tmp_tau_1 = kendall_tau(align)
            # word order with the two spans swapped
            tmp_e_list = [e_list[i] for i in range(span_min - 1)]
            tmp_e_list.extend([e_list[i - 1] for i in right_span])
            tmp_e_list.extend([e_list[i - 1] for i in left_span])
            tmp_e_list.extend(
                [e_list[i] for i in range(span_max, len(e_list))])
            align = []
            for f_w_dst in f_dst_list:
                for e_dst in f_w_dst[1]:
                    if e_dst in tmp_e_list:
                        align.append(tmp_e_list.index(e_dst))
            tmp_tau_2 = kendall_tau(align)
            if tmp_tau_2 > tmp_tau_1:
                swap = 1
                span_list = right_span + left_span
            else:
                span_list = left_span + right_span
            text = node['text'] if 'text' in node else ""
            tail = node['tail'] if 'tail' in node else ""
            return span_list, {
                'tag': node['tag'],
                'label': swap,
                'node': (left_node, right_node),
                'cat': node['cat'],
                'text': text,
                'tail': tail
            }
    elif node['tag'] == 'tok':
        word = node['text'].lower()
        t = vocab[word] if word in vocab else vocab['<UNK>']
        return [j + 1], {
            'tag': node['tag'],
            'node': t,
            'text': node['text'],
            'pos': node['pos']
        }
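# A tiny worked check of the swap decision above, using the kendall_tau
# sketch from earlier (the data is hypothetical): when the target words align
# to the two source spans in reversed order, swapping the spans raises tau,
# so convert_tree_dev would set swap = 1.
assert kendall_tau([0, 1]) > kendall_tau([1, 0])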