def choose_random_forms(ltm_paradigms,
                        vocab,
                        gold_pos,
                        morph,
                        n_samples=10,
                        gold_word=None):
    """Randomly pick up to ``n_samples`` (lemma, form, form_alt) triples.

    A lemma is considered only when ``gold_pos`` is the single POS key it has
    in ``ltm_paradigms``.  For such a lemma, ``form`` is the entry stored
    under ``morph`` and ``form_alt`` is the entry stored under the second
    value returned by ``alt_numeral_morph(morph)``.  The triple is kept only
    if ``is_good_form`` accepts it.

    Args:
        ltm_paradigms: nested mapping indexed as
            ``ltm_paradigms[lemma][pos][morph]``.
        vocab: passed through unchanged to ``is_good_form``.
        gold_pos: POS key the candidate lemma must have (and no other).
        morph: morphological key used to look up ``form``.
        n_samples: maximum number of triples to return.
        gold_word: passed through to ``is_good_form`` as its first argument.

    Returns:
        A list of at most ``n_samples`` distinct triples, in random order;
        an empty list when no lemma qualifies.
    """
    candidates = set()

    for lemma in ltm_paradigms:
        # Dict keys are already unique, so comparing against a one-element
        # list checks "exactly one POS and it is gold_pos".
        if list(ltm_paradigms[lemma]) != [gold_pos]:
            continue

        forms_by_morph = ltm_paradigms[lemma][gold_pos]
        form = forms_by_morph[morph]
        _, morph_alt = alt_numeral_morph(morph)
        form_alt = forms_by_morph[morph_alt]

        if not is_good_form(gold_word, form, morph, lemma, gold_pos, vocab,
                            ltm_paradigms):
            continue

        candidates.add((lemma, form, form_alt))

    # random.sample() requires a sequence: passing a set was deprecated in
    # Python 3.9 and raises TypeError from 3.11 on.
    pool = list(candidates)
    return random.sample(pool, min(n_samples, len(pool)))
def _filter_infrequent_deps(context_deps, threshold):
    """Copy ``context_deps`` keeping only inner entries with freq > threshold."""
    return defaultdict(
        lambda: defaultdict(int), {
            c: defaultdict(int, {
                lr: freq
                for lr, freq in d.items() if freq > threshold
            })
            for c, d in context_deps.items()
        })


def _write_pattern_matches(f_out, side, patterns, trees, arc_dir, output_dir,
                           feature_names, vocab, ltm_paradigms):
    """Write each pattern to ``f_out`` and its matching sentences to a file.

    ``side`` ("L" or "R") prefixes both the line written to ``f_out`` and the
    name of the per-pattern file created under ``output_dir``.
    """
    for p in patterns:
        pattern_name = "_".join(p[0])
        pattern_line = side + "\t" + pattern_name + "\t" + "\t".join(
            p[1:]) + "\n"
        f_out.write(pattern_line)
        print(pattern_line)

        # Context manager so the per-pattern file is closed even if matching
        # or formatting raises part-way through.
        with open(output_dir + "/" + side + "_" + pattern_name,
                  "w") as f_out_grep:
            for _context, left, right, _tree, nodes in grep_morph_pattern(
                    trees, p[0], p[1:], arc_dir, feature_names):
                in_vocab = all(n.word in vocab for n in nodes + [left, right])
                in_paradigms = is_good_form(right.word, right.word,
                                            right.morph, right.lemma,
                                            right.pos, vocab, ltm_paradigms)
                # NOTE(review): the two flags are concatenated without a
                # separator (e.g. "TrueFalse"); kept as-is because the
                # consumers of this format are not visible here.
                f_out_grep.write(
                    features(left.morph, feature_names) + " " +
                    features(right.morph, feature_names) + "\t" +
                    str(in_vocab) + str(in_paradigms) + "\t" + left.word +
                    " " + " ".join([n.word for n in nodes]) + " " +
                    right.word + "\n")


def main():
    """Command-line entry point: extract agreement patterns from a treebank.

    Parses the command-line options, loads the treebank, collects left/right
    morphological context frequencies, drops contexts whose frequency does
    not exceed 1, selects good patterns using ``--freq``, and writes:

    * ``<output>/patterns.txt`` with one line per selected pattern, and
    * one ``<output>/L_<context>`` or ``<output>/R_<context>`` file per
      pattern listing the sentences that match it.
    """
    parser = argparse.ArgumentParser(
        description=
        'Extracting dependency-based long-distance agreement patterns')

    parser.add_argument(
        '--treebank',
        type=str,
        required=True,
        help='Path of the input treebank file (in a column format)')
    parser.add_argument('--output',
                        type=str,
                        required=True,
                        help="Path for the output files")
    parser.add_argument(
        '--features',
        type=str,
        default="Number",
        help=
        "A list of morphological features which will be used, in Number|Case|Gender format"
    )
    parser.add_argument('--freq',
                        type=int,
                        default=5,
                        help="minimal frequency")
    parser.add_argument('--vocab',
                        type=str,
                        required=False,
                        help="LM vocab - to compute which sentences have OOV")
    parser.add_argument('--paradigms',
                        type=str,
                        required=False,
                        help="File with morphological paradigms - to compute"
                        "which sentences have both target pairs")

    args = parser.parse_args()

    if args.vocab:
        vocab = load_vocab(args.vocab)
    else:
        vocab = []

    print("Loading trees")
    trees = tm.load_trees_from_conll(args.treebank)

    # needed for original UD treebanks (e.g. Italian) which contain spans, e.g. 10-12
    # annotating multimorphemic words as several nodes in the tree
    for t in trees:
        t.remerge_segmented_morphemes()

    if args.features:
        args.features = args.features.split("|")
        print("Features", args.features)

    print("Extracting contexts")
    context_left_deps, context_right_deps = morph_contexts_frequencies(
        trees, args.features)

    # filtering very infrequent cases
    filter_threshold = 1
    context_left_deps = _filter_infrequent_deps(context_left_deps,
                                                filter_threshold)
    context_right_deps = _filter_infrequent_deps(context_right_deps,
                                                 filter_threshold)

    print("Finding good patterns")
    good_patterns_left = find_good_patterns(context_left_deps, args.freq)
    good_patterns_right = find_good_patterns(context_right_deps, args.freq)

    # Context manager so patterns.txt is flushed and closed even when one of
    # the later steps raises.
    with open(args.output + "/patterns.txt", "w") as f_out:
        print("Saving patterns and sentences matching them")

        # NOTE(review): --paradigms is declared optional but is used
        # unconditionally here; confirm read_paradigms copes with None.
        ltm_paradigms = ltm_to_word(read_paradigms(args.paradigms))

        _write_pattern_matches(f_out, "L", good_patterns_left, trees,
                               tm.Arc.LEFT, args.output, args.features, vocab,
                               ltm_paradigms)
        _write_pattern_matches(f_out, "R", good_patterns_right, trees,
                               tm.Arc.RIGHT, args.output, args.features,
                               vocab, ltm_paradigms)
def generate_morph_pattern_test(trees,
                                pattern,
                                paradigms,
                                vocab,
                                n_sentences=10):
    """Build test sentences and gold annotation lines for one pattern.

    ``pattern`` is a tab-separated string: arc direction, an
    underscore-joined context, then the left-hand values.  For every match
    returned by ``grep_morph_pattern`` whose words are all in ``vocab`` and
    whose right-hand word passes ``is_good_form``, ``n_sentences`` items are
    produced: item 0 is the original sentence, the others use
    ``generate_context`` and a form drawn by ``choose_random_forms``.

    Returns:
        A list of ``(sentence, gold_line)`` tuples.  ``sentence`` ends with
        " <eos>" and a newline; ``gold_line`` is a tab-separated,
        newline-terminated record.
    """
    fields = pattern.split("\t")
    arc_dir, context_str = fields[:2]
    context = tuple(context_str.split("_"))
    l_values = fields[2:]
    pattern_id = pattern.replace("\t", "!")

    ltm_paradigms = ltm_to_word(paradigms)

    output = []
    # Numbers only the constructions that survive both filters below.
    constr_id = 0
    skipped_vocab = 0
    skipped_paradigms = 0

    # 'nodes' constitute Y, without X or Z included
    matches = grep_morph_pattern(trees, context, l_values, arc_dir)
    for _matched_context, l, r, t, nodes in matches:
        # filter model sentences with unk and the choice word not in vocab
        if any(n.word not in vocab for n in nodes + [l, r]):
            skipped_vocab += 1
            continue
        if not is_good_form(r.word, r.word, r.morph, r.lemma, r.pos, vocab,
                            ltm_paradigms):
            skipped_paradigms += 1
            continue

        prefix = " ".join(n.word for n in t.nodes[:r.index])

        for sent_id in range(n_sentences):
            if sent_id == 0:
                # sent_id 0 is the original sentence with its own lexical items
                new_context = " ".join(n.word for n in t.nodes)
                lemma, form = r.lemma, r.word
                form_alt = get_alt_form(r.lemma, r.pos, r.morph,
                                        ltm_paradigms)
            else:
                # every other sent_id is a generated sentence
                new_context = generate_context(t.nodes, paradigms, vocab)
                sampled = choose_random_forms(ltm_paradigms,
                                              vocab,
                                              r.pos,
                                              r.morph,
                                              n_samples=1,
                                              gold_word=r.word)
                if sampled:
                    lemma, form, form_alt = sampled[0]
                else:
                    # in rare cases, there is no (form, form_alt) both in vocab
                    # original form and its alternation are not found because
                    # e.g. one or the other is not in paradigms
                    # (they should anyway be in the vocabulary)
                    lemma, form = r.lemma, r.word
                    form_alt = get_alt_form(r.lemma, r.pos, r.morph,
                                            ltm_paradigms)

            # constr_id sent_id Z_index Z_pos Z_gold_morph
            gold_fields = [
                pattern_id,
                str(constr_id),
                str(sent_id),
                str(r.index - 1),
                r.pos,
                r.morph,
                form,
                form_alt,
                lemma,
                str(l.index - 1),
                l.pos,
                prefix,
            ]
            gold_str = "\t".join(gold_fields) + "\n"

            output.append((new_context + " <eos>\n", gold_str))

        constr_id += 1

    print("Problematic sentences vocab/paradigms", skipped_vocab,
          skipped_paradigms)
    return output