import os
from shutil import copyfile

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from edparser.components.parsers.biaffine_parser import BiaffineTransformerDependencyParser
from edparser.metrics.parsing.iwpt20_eval import evaluate
from iwpt2020 import cdroot

# NOTE: the import paths below are assumptions; point them at wherever the
# semantic parser and these helpers actually live in your checkout.
from edparser.components.parsers.biaffine_parser import BiaffineTransformerSemanticDependencyParser
from edparser.utils.io_util import (get_resource, load_json, load_pickle,
                                    save_json, save_pickle)
from edparser.utils.log_util import init_logger
from iwpt2020.util import (conllu_quick_fix, limit_len, merge_long_sent,
                           remove_complete_edges, restore_collapse_edges,
                           sdp_to_dag)


def eval_sdp_and_ensemble(parser,
                          devfile,
                          dep_dev_output,
                          save_dir,
                          lang,
                          logger,
                          do_eval=True):
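    """
    Evaluate the semantic dependency parser on ``devfile``, then ensemble its
    scores with the tree parser's output (``dep_dev_output``), scoring both
    with the official IWPT 2020 evaluator. Sentences that were split for
    length are restored via the accompanying ``.long.json`` mapping.

    Returns ``(sdp_elas, sdp_output), (ensemble_elas, ensemble_output)``,
    where the paths point to the '.fixed.conllu' variants.
    """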
    long_sent: dict = load_json(devfile.replace('.short.conllu', '.long.json'))
    long_sent = dict((int(x), y) for x, y in long_sent.items())
    sdp_dev_output = f'{save_dir}/{os.path.basename(devfile.replace(".conllu", ".sdp.pred.conllu"))}'
    sdp_dev_output = sdp_dev_output.replace('.short', '')
    if not os.path.isfile(sdp_dev_output) or do_eval:
        if not parser.model:
            parser.load(save_dir)
        scores = parser.evaluate(devfile,
                                 save_dir,
                                 warm_up=False,
                                 ret_scores=True,
                                 logger=logger,
                                 batch_size=256 if lang == 'cs' else None)[-1]
        sdp_to_dag(parser, scores, sdp_dev_output, long_sent)
    score = evaluate(devfile.replace('.short', ''), sdp_dev_output)
    final_sdp_dev_output = sdp_dev_output.replace('.conllu', '.fixed.conllu')
    sdp_elas = score["ELAS"].f1
    sdp_clas = score["CLAS"].f1
    logger.info(f'SDP score for {lang}:')
    logger.info(f'ELAS: {sdp_elas * 100:.2f} - CLAS: {sdp_clas * 100:.2f}')
    ensemble_output = f'{save_dir}/{os.path.basename(devfile.replace(".conllu", ".ensemble.pred.conllu"))}'
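    # Same condition as the first branch, so `scores` is guaranteed to be
    # defined whenever sdp_to_dag runs here.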
    if not os.path.isfile(sdp_dev_output) or do_eval:
        sdp_to_dag(parser, scores, ensemble_output, long_sent, dep_dev_output)
    score = evaluate(devfile.replace('.short', ''), ensemble_output)
    final_ensemble_output = ensemble_output.replace('.conllu', '.fixed.conllu')
    logger.info(f'Ensemble score for {lang}:')
    ensemble_elas = score["ELAS"].f1
    logger.info(f'ELAS: {ensemble_elas * 100:.2f} - CLAS: {score["CLAS"].f1 * 100:.2f}')
    return (sdp_elas, final_sdp_dev_output), (ensemble_elas,
                                              final_ensemble_output)

cdroot()
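# English-only experiment: a BERT-large dependency parser on UD_English-EWT
# (training and loading are left commented out; only scoring runs here).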
save_dir = 'data/model/iwpt2020/en_bert_large_dep'
parser = BiaffineTransformerDependencyParser()
dataset = 'data/iwpt2020/train-dev/'
trnfile = f'{dataset}UD_English-EWT/en_ewt-ud-train.enhanced_collapse_empty_nodes.conllu'
devfile = f'{dataset}UD_English-EWT/en_ewt-ud-dev.enhanced_collapse_empty_nodes.conllu'
testfile = devfile
# parser.fit(trnfile,
#            devfile,
#            save_dir, 'bert-large-uncased-whole-word-masking',
#            batch_size=128,
#            warmup_steps_ratio=.1,
#            samples_per_batch=150,
#            # max_samples_per_batch=32,
#            transformer_dropout=.33,
#            learning_rate=2e-3,
#            learning_rate_transformer=1e-5,
#            epochs=1
#            )
# parser.load(save_dir, tree='tarjan')
output = testfile.replace('.conllu', '.pred.conllu')
# parser.evaluate(devfile, save_dir, warm_up=False, output=output)
score = evaluate(testfile, output)
print(f'ELAS: {score["ELAS"].f1 * 100:.2f} - CLAS: {score["CLAS"].f1 * 100:.2f}')
print(f'Model saved in {save_dir}')
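
# Plot ELAS against sentence length for multilingual (mbert) vs.
# language-specific (bert) encoders and the three systems (dep/sdp/ens),
# caching the scores in cache.pkl so re-runs skip re-evaluation.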
try:
    cache = load_pickle('cache.pkl')
except FileNotFoundError:
    cache = {}
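# NOTE: `template` and `gold_file` are not defined anywhere in this snippet.
# The two paths below are hypothetical placeholders: `template` must name a
# prediction file containing the substrings 'bert' and 'dep' (substituted per
# encoder/model below), and `gold_file` the matching gold CoNLL-U file.
template = 'data/model/iwpt2020/en/bert_dep/dev.dep.pred.conllu'
gold_file = 'data/iwpt2020/train-dev-combined/en/dev.conllu'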
for enc in ['mbert', 'bert']:
    for model, color in zip(['dep', 'sdp', 'ens'], 'rgb'):
        key = f'{enc}-{model}'
        if key in cache:
            xs, ys = cache[key]
        else:
            pred_file = template.replace('bert', enc).replace('dep', model)
            xs = np.arange(5, 50, 5)
            ys = [
                evaluate(limit_len(gold_file, l),
                         limit_len(pred_file, l),
                         do_copy_cols=False,
                         do_enhanced_collapse_empty_nodes=True)['ELAS'].f1
                for l in xs
            ]
            cache[key] = (xs, ys)
        label = (key.replace('mbert', 'multilingual')
                    .replace('bert', 'language-specific')
                    .replace('dep', 'DTP')
                    .replace('sdp', 'DGP')
                    .replace('ens', 'ENS'))
        plt.plot(xs, ys,
                 label=label,
                 color=color,
                 linestyle='-' if enc == 'mbert' else '--')
        print(key)
save_pickle(cache, 'cache.pkl')
plt.xlabel('sentence length')
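# Finish the figure: the labels set in the loop imply a legend, and the
# y-values are ELAS scores. (plt.show() is a stand-in; save the figure
# instead if you prefer.)
plt.ylabel('ELAS')
plt.legend()
plt.show()
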
def run(lang, do_train=True, do_eval=True, mbert=True):
    """
    Run training and decoding
    :param lang: Language code, 2 letters.
    :param do_train: Train model or not.
    :param do_eval: Evaluate performance (generating output) or not.
    :param mbert: Use mbert or language specific transformers.
    """
    dataset = f'data/iwpt2020/train-dev-combined/{lang}'
    trnfile = f'{dataset}/train.short.conllu'
    # for idx, sent in enumerate(read_conll(trnfile)):
    #     print(f'\r{idx}', end='')
    devfile = f'{dataset}/dev.short.conllu'
    testfile = f'data/iwpt2020/test-udpipe/{lang}.fixed.short.conllu'
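    # Choose the pretrained encoder: multilingual BERT by default, a
    # language-specific model where one is available.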
    prefix = 'mbert'
    transformer = 'bert-base-multilingual-cased'
    if not mbert:
        prefix = 'bert'
        if lang == 'sv':
            transformer = "KB/bert-base-swedish-cased"
        elif lang == 'ar':
            transformer = "asafaya/bert-base-arabic"
        elif lang == 'en':
            transformer = 'albert-xxlarge-v2'
        elif lang == 'ru':
            transformer = "DeepPavlov/rubert-base-cased"
        elif lang == 'fi':
            transformer = "TurkuNLP/bert-base-finnish-cased-v1"
        elif lang == 'it':
            transformer = "dbmdz/bert-base-italian-cased"
        elif lang == 'nl':
            transformer = "wietsedv/bert-base-dutch-cased"
        elif lang == 'et':
            transformer = get_resource(
                'http://dl.turkunlp.org/estonian-bert/etwiki-bert/pytorch/etwiki-bert-base-cased.tar.gz'
            )
        elif lang == 'fr':
            transformer = 'camembert-base'
        elif lang == 'pl':
            transformer = "dkleczek/bert-base-polish-uncased-v1"
        elif lang == 'sk' or lang == 'bg' or lang == 'cs':
            transformer = get_resource(
                'http://files.deeppavlov.ai/deeppavlov_data/bert/bg_cs_pl_ru_cased_L-12_H-768_A-12_pt.tar.gz'
            )
        else:
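            # No language-specific model is available; fall back to mBERT.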
            prefix = 'mbert'
    save_dir = f'data/model/iwpt2020/{lang}/{prefix}_dep'
    # if do_train and os.path.isdir(save_dir):
    #     return
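    # Stage 1: biaffine dependency tree parser, trained under a
    # MirroredStrategy so multiple GPUs are used when present.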
    strategy = tf.distribute.MirroredStrategy()
    print("Number of devices: {}".format(strategy.num_replicas_in_sync))
    with strategy.scope():
        parser = BiaffineTransformerDependencyParser(strategy=strategy)
        if do_train:
            parser.fit(
                trnfile,
                devfile,
                save_dir,
                transformer,
                batch_size=4096,
                warmup_steps_ratio=.1,
                samples_per_batch=150,
                # max_samples_per_batch=75,
                transformer_dropout=.33,
                learning_rate=2e-3,
                learning_rate_transformer=1e-5,
                # max_seq_length=512,
                # epochs=1
            )
    logger = init_logger(name='test', root_dir=save_dir, mode='w')
    parser.config.tree = 'mst'
    # dep_dev_output = f'{save_dir}/{os.path.basename(devfile.replace(".conllu", ".dep.pred.conllu"))}'
    # if not os.path.isfile(dep_dev_output) or do_eval:
    #     parser.evaluate(devfile, save_dir, warm_up=False, output=dep_dev_output, logger=logger)
    dep_test_output = f'{save_dir}/{os.path.basename(testfile.replace(".conllu", ".dep.pred.conllu"))}'
    if not os.path.isfile(dep_test_output) or do_eval:
        parser.load(save_dir, tree='mst')
        parser.evaluate(testfile,
                        save_dir,
                        warm_up=False,
                        output=dep_test_output,
                        logger=None)
    # score = evaluate(devfile, dep_dev_output)
    # dep_dev_elas = score["ELAS"].f1
    # dep_dev_clas = score["CLAS"].f1
    # logger.info(f'DEP score for {lang}:')
    # logger.info(f'ELAS: {dep_dev_elas * 100:.2f} - CLAS:{dep_dev_clas * 100:.2f}')
    if do_train:
        print(f'Model saved in {save_dir}')

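    # Stage 2: semantic dependency graph parser, evaluated on its own and
    # ensembled with the tree parser's test predictions.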
    save_dir = f'data/model/iwpt2020/{lang}/{prefix}_sdp'
    parser = BiaffineTransformerSemanticDependencyParser()
    if do_train and not os.path.isdir(save_dir):
        parser.fit(
            trnfile,
            devfile,
            save_dir,
            transformer,
            batch_size=1000 if lang == 'cs' else 3000,
            warmup_steps_ratio=.1,
            samples_per_batch=150,
            # max_samples_per_batch=150,
            transformer_dropout=.33,
            learning_rate=2e-3,
            learning_rate_transformer=1e-5,
            # max_seq_length=512,
            # epochs=1
        )
    # (sdp_dev_elas, final_sdp_dev_output), (ensemble_dev_elas, final_ensemble_dev_output) = \
    #     eval_sdp_and_ensemble(parser, devfile, dep_dev_output, save_dir, lang, logger)
    (sdp_test_elas, final_sdp_test_output), (ensemble_test_elas, final_ensemble_test_output) = \
        eval_sdp_and_ensemble(parser, testfile, dep_test_output, save_dir, lang, logger, do_eval)
    save_dir = f'data/model/iwpt2020/{lang}/'
    # copyfile(dep_dev_output, save_dir + 'dev.dep.conllu')
    # copyfile(final_sdp_dev_output, save_dir + 'dev.sdp.conllu')
    # copyfile(final_ensemble_dev_output, save_dir + 'dev.ens.conllu')
    # dev_scores = [dep_dev_elas, sdp_dev_elas, ensemble_dev_elas]
    # winner = max(dev_scores)
    # widx = dev_scores.index(winner)
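    # Merge the sentences that were split for length back into a full-length
    # test file and score it with the official evaluator.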
    dep_test_output = merge_long_sent(dep_test_output)
    evaluate(f'data/iwpt2020/test-udpipe/{lang}.fixed.conllu', dep_test_output)
    dep_test_output = dep_test_output.replace('.conllu', '.fixed.conllu')
    # if widx == 0:
    #     # dep wins, but we don't have output for dep, so let's do it below
    #     best_test_output = dep_test_output
    #     best_task = 'dep'
    # elif widx == 1:
    #     # sdp wins
    #     best_test_output = final_sdp_test_output
    #     best_task = 'sdp'
    # else:
    #     # ensemble wins
    #     best_test_output = final_ensemble_test_output
    #     best_task = 'ens'
    #
    # info = {
    #     'best_task': best_task,
    #     'dev_scores': dict((x, y) for x, y in zip(['dep', 'sdp', 'ens'], dev_scores))
    # }
    # save_json(info, save_dir + 'scores.json')
    # copyfile(best_test_output, save_dir + lang + '.conllu')
    # dev_json = 'data/model/iwpt2020/dev.json'
    # try:
    #     total = load_json(dev_json)
    # except FileNotFoundError:
    #     total = {}
    # total[lang] = info
    # save_json(total, dev_json)

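    # Collect the final outputs of all three systems into per-task
    # submission folders.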
    final_root = f'data/model/iwpt2020/{prefix}'
    dep_root = f'{final_root}/dep'
    sdp_root = f'{final_root}/sdp'
    ens_root = f'{final_root}/ens'
    outputs = [
        dep_test_output, final_sdp_test_output, final_ensemble_test_output
    ]
    folders = [dep_root, sdp_root, ens_root]
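    # Post-process each prediction in a temp file: remove_complete_edges and
    # restore_collapse_edges undo the empty-node collapsing, then
    # conllu_quick_fix writes a valid CoNLL-U file into the submission folder.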
    for pred, folder in zip(outputs, folders):
        os.makedirs(folder, exist_ok=True)
        tmp = f'/tmp/{lang}.conllu'
        copyfile(pred, tmp)
        remove_complete_edges(tmp, tmp)
        restore_collapse_edges(tmp, tmp)
        conllu_quick_fix(tmp, f'{folder}/{lang}.conllu')
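

# Hypothetical driver: the language list below is an assumption (the IWPT 2020
# shared-task languages); trim it to the treebanks you actually have.
if __name__ == '__main__':
    for code in ['ar', 'bg', 'cs', 'en', 'et', 'fi', 'fr', 'it', 'lt', 'lv',
                 'nl', 'pl', 'ru', 'sk', 'sv', 'ta', 'uk']:
        run(code, do_train=True, do_eval=True, mbert=True)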