Example #1
def biaffine(model_path, model_name, test_path, punct_set, use_gpu, logger,
             args):
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_data.create_alphabets(
        alphabet_path, None, data_paths=[None, None], max_vocabulary_size=50000, embedd_dict=None)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    decoding = args.decode

    logger.info('use gpu: %s, decoding: %s' % (use_gpu, decoding))

    data_test = conllx_data.read_data_to_tensor(test_path,
                                                word_alphabet,
                                                char_alphabet,
                                                pos_alphabet,
                                                type_alphabet,
                                                use_gpu=use_gpu,
                                                volatile=True,
                                                symbolic_root=True)

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)

    logger.info('model: %s' % model_name)

    def load_model_arguments_from_json():
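        # The saved model's constructor args/kwargs are read back from the companion '<model>.arg.json' file.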
        arguments = json.load(open(arg_path, 'r'))
        return arguments['args'], arguments['kwargs']

    arg_path = model_name + '.arg.json'
    args, kwargs = load_model_arguments_from_json()
    network = BiRecurrentConvBiAffine(*args, **kwargs)
    network.load_state_dict(torch.load(model_name))

    if use_gpu:
        network.cuda()
    else:
        network.cpu()

    network.eval()

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0
    test_total = 0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_total_nopunc = 0
    test_total_inst = 0

    test_root_correct = 0.0
    test_total_root = 0

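    # Choose the decoding routine requested on the command line: greedy head selection or MST decoding.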
    if decoding == 'greedy':
        decode = network.decode
    elif decoding == 'mst':
        decode = network.decode_mst
    else:
        raise ValueError('Unknown decoding algorithm: %s' % decoding)

    pred_writer.start('tmp/analyze_pred_%s' % str(uid))
    gold_writer.start('tmp/analyze_gold_%s' % str(uid))
    sent = 0
    start_time = time.time()

    for batch in conllx_data.iterate_batch_tensor(data_test, 1):
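        # Decode one sentence at a time (batch size 1), write predicted and gold trees, and accumulate evaluation statistics.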
        sys.stdout.write('%d, ' % sent)
        sys.stdout.flush()
        sent += 1

        word, char, pos, heads, types, masks, lengths = batch
        heads_pred, types_pred = decode(
            word,
            char,
            pos,
            mask=masks,
            length=lengths,
            leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
        word = word.data.cpu().numpy()
        pos = pos.data.cpu().numpy()
        lengths = lengths.cpu().numpy()
        heads = heads.data.cpu().numpy()
        types = types.data.cpu().numpy()

        pred_writer.write(word,
                          pos,
                          heads_pred,
                          types_pred,
                          lengths,
                          symbolic_root=True)
        gold_writer.write(word, pos, heads, types, lengths, symbolic_root=True)

        stats, stats_nopunc, stats_root, num_inst = parser.eval(
            word,
            pos,
            heads_pred,
            types_pred,
            heads,
            types,
            word_alphabet,
            pos_alphabet,
            lengths,
            punct_set=punct_set,
            symbolic_root=True)
        ucorr, lcorr, total, ucm, lcm = stats
        ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
        corr_root, total_root = stats_root

        test_ucorrect += ucorr
        test_lcorrect += lcorr
        test_total += total
        test_ucomlpete_match += ucm
        test_lcomplete_match += lcm

        test_ucorrect_nopunc += ucorr_nopunc
        test_lcorrect_nopunc += lcorr_nopunc
        test_total_nopunc += total_nopunc
        test_ucomlpete_match_nopunc += ucm_nopunc
        test_lcomplete_match_nopunc += lcm_nopunc

        test_root_correct += corr_root
        test_total_root += total_root

        test_total_inst += num_inst

    pred_writer.close()
    gold_writer.close()

    print('\ntime: %.2fs' % (time.time() - start_time))
    print('test W. Punct:  ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
        test_ucorrect, test_lcorrect, test_total,
        test_ucorrect * 100 / test_total, test_lcorrect * 100 / test_total,
        test_ucomlpete_match * 100 / test_total_inst, test_lcomplete_match * 100 / test_total_inst))
    print('test Wo Punct:  ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
        test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
        test_ucorrect_nopunc * 100 / test_total_nopunc, test_lcorrect_nopunc * 100 / test_total_nopunc,
        test_ucomlpete_match_nopunc * 100 / test_total_inst, test_lcomplete_match_nopunc * 100 / test_total_inst))
    print('test Root: corr: %d, total: %d, acc: %.2f%%' % (
        test_root_correct, test_total_root, test_root_correct * 100 / test_total_root))
Example #2
def biaffine(model_path, model_name, test_path, punct_set, use_gpu, logger, args):
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet, max_sent_length = conllx_data.create_alphabets(alphabet_path,
        None, data_paths=[None, None], max_vocabulary_size=50000, embedd_dict=None)
    # word_alphabet, char_alphabet, pos_alphabet, type_alphabet = create_alphabets(alphabet_path,
    #     None, data_paths=[None, None], max_vocabulary_size=50000, embedd_dict=None)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    decoding = args.decode
    out_filename = args.out_filename
    constraints_method = args.constraints_method
    constraintFile = args.constraint_file
    ratioFile = args.ratio_file
    tolerance = args.tolerance
    gamma = args.gamma
    the_language = args.mt_log[9:11]
    mt_log = open(args.mt_log, 'a')
    summary_log = open(args.summary_log, 'a')
    logger.info('use gpu: %s, decoding: %s' % (use_gpu, decoding))

    #
    extra_embeds_arr = augment_with_extra_embedding(word_alphabet, args.extra_embed, args.extra_embed_src, test_path, logger)

    # ===== the reading
    def _read_one(path, is_train):
        lang_id = guess_language_id(path)
        logger.info("Reading: guess that the language of file %s is %s." % (path, lang_id))
        one_data = conllx_data.read_data_to_variable(path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, use_gpu=use_gpu, volatile=(not is_train), symbolic_root=True, lang_id=lang_id)
        return one_data

    data_test = _read_one(test_path, False)

    # data_test = conllx_data.read_data_to_variable(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet,
    #                                               use_gpu=use_gpu, volatile=True, symbolic_root=True)

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet)

    logger.info('model: %s' % model_name)

    def load_model_arguments_from_json():
        arguments = json.load(open(arg_path, 'r'))
        return arguments['args'], arguments['kwargs']

    arg_path = model_name + '.arg.json'
    args, kwargs = load_model_arguments_from_json()
    network = BiRecurrentConvBiAffine(use_gpu=use_gpu, *args, **kwargs)
    network.load_state_dict(torch.load(model_name))

    #
    augment_network_embed(word_alphabet.size(), network, extra_embeds_arr)

    if use_gpu:
        network.cuda()
    else:
        network.cpu()

    network.eval()


    if decoding == 'greedy':
        decode = network.decode
    elif decoding == 'mst':
        decode = network.decode_mst
    elif decoding == 'proj':
        decode = network.decode_proj
    else:
        raise ValueError('Unknown decoding algorithm: %s' % decoding)

    # pred_writer.start('tmp/analyze_pred_%s' % str(uid))
    # gold_writer.start('tmp/analyze_gold_%s' % str(uid))
    # pred_writer.start(model_path + out_filename + '_pred')
    # gold_writer.start(model_path + out_filename + '_gold')
    pred_writer.start(out_filename + '_pred')
    gold_writer.start(out_filename + '_gold')

    sent = 0
    start_time = time.time()

    constraints = []
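    # Build the constraint set: from WALS typological features, from a tab-separated POS-pair ratio file, or none for the unconstrained baseline.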
    
    mt_log.write("=====================%s, Ablation 2================\n"%(constraints_method))
    summary_log.write("==========================%s, Ablation 2=============\n"%(constraints_method))
    if ratioFile == 'WALS':
        import pickle as pk
        cFile = open(constraintFile, 'rb')
        WALS_data = pk.load(cFile)
        for idx in ['85A', '87A', '89A']:
            constraint = Constraint(0,0,0)
            extra_const = constraint.load_WALS(idx, WALS_data[the_language][idx], pos_alphabet, method=constraints_method)
            constraints.append(constraint)
            if extra_const:
                constraints.append(extra_const)
        constraint = Constraint(0,0,0)
        extra_const = constraint.load_WALS_unary(WALS_data[the_language], pos_alphabet, method=constraints_method)
        if extra_const:
            constraints.append(extra_const)
        constraints.append(constraint)
    elif ratioFile == 'None':
        summary_log.write("=================No it is baseline================\n")
        mt_log.write("==================No it is baseline==============\n")
    else:
        cFile = open(constraintFile, 'r')
        for line in cFile:
            if len(line.strip()) < 2:
                break
            pos1, pos2 = line.strip().split('\t')
            constraint = Constraint(0,0,0)
            constraint.load(pos1, pos2, ratioFile, pos_alphabet)
            constraints.append(constraint)
    
    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0
    test_total = 0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_total_nopunc = 0
    test_total_inst = 0

    test_root_correct = 0.0
    test_total_root = 0
    arc_list = []
    type_list = []
    length_list = []
    pos_list = []
    
    for batch in conllx_data.iterate_batch_variable(data_test, 1):
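        # First pass over the test data: collect arc/type scores, lengths and POS tags used to fit the constraint weights.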
        word, char, pos, heads, types, masks, lengths = batch
        out_arc, out_type, length = network.pretrain_constraint(word, char, pos, mask=masks, length=lengths, leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
        arc_list += list(out_arc)
        type_list += list(out_type)
        length_list += list(length)
        pos_list += list(pos)
        
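    # Fit the constraints with the selected method ('binary', 'Lagrange', or 'PR').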
    if constraints_method == 'binary':
        train_constraints = network.binary_constraints
    elif constraints_method == 'Lagrange':
        train_constraints = network.Lagrange_constraints
    elif constraints_method == 'PR':
        train_constraints = network.PR_constraints
    else:
        raise ValueError('Unknown constraints method: %s' % constraints_method)
    train_constraints(arc_list, type_list, length_list, pos_list, constraints, tolerance, mt_log, gamma=gamma)

    for batch in conllx_data.iterate_batch_variable(data_test, 1):
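        # Second pass: decode with the fitted constraints and evaluate against the gold trees.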
        #sys.stdout.write('%d, ' % sent)
        #sys.stdout.flush()
        sent += 1

        word, char, pos, heads, types, masks, lengths = batch
        heads_pred, types_pred = decode(word, char, pos, mask=masks, length=lengths,
                                        leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS, constraints=constraints, method=constraints_method, gamma=gamma)
        word = word.data.cpu().numpy()
        pos = pos.data.cpu().numpy()
        lengths = lengths.cpu().numpy()
        heads = heads.data.cpu().numpy()
        types = types.data.cpu().numpy()

        pred_writer.write(word, pos, heads_pred, types_pred, lengths, symbolic_root=True)
        gold_writer.write(word, pos, heads, types, lengths, symbolic_root=True)

        stats, stats_nopunc, stats_root, num_inst = parser.eval(word, pos, heads_pred, types_pred, heads, types,
                                                                word_alphabet, pos_alphabet, lengths,
                                                                punct_set=punct_set, symbolic_root=True)
        ucorr, lcorr, total, ucm, lcm = stats
        ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
        corr_root, total_root = stats_root

        test_ucorrect += ucorr
        test_lcorrect += lcorr
        test_total += total
        test_ucomlpete_match += ucm
        test_lcomplete_match += lcm

        test_ucorrect_nopunc += ucorr_nopunc
        test_lcorrect_nopunc += lcorr_nopunc
        test_total_nopunc += total_nopunc
        test_ucomlpete_match_nopunc += ucm_nopunc
        test_lcomplete_match_nopunc += lcm_nopunc

        test_root_correct += corr_root
        test_total_root += total_root

        test_total_inst += num_inst

    print('\ntime: %.2fs' % (time.time() - start_time))
    print('test W. Punct:  ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
        test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 / test_total, test_lcorrect * 100 / test_total,
        test_ucomlpete_match * 100 / test_total_inst, test_lcomplete_match * 100 / test_total_inst))
    print('test Wo Punct:  ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
        test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
        test_ucorrect_nopunc * 100 / test_total_nopunc, test_lcorrect_nopunc * 100 / test_total_nopunc,
        test_ucomlpete_match_nopunc * 100 / test_total_inst, test_lcomplete_match_nopunc * 100 / test_total_inst))
    print('test Root: corr: %d, total: %d, acc: %.2f%%' % (
        test_root_correct, test_total_root, test_root_correct * 100 / test_total_root))
    mt_log.write('uas: %.2f, las: %.2f\n'%(test_ucorrect_nopunc * 100 / test_total_nopunc, test_lcorrect_nopunc * 100 / test_total_nopunc))
    summary_log.write('%s: %.2f %.2f\n'%(the_language, test_ucorrect_nopunc * 100 / test_total_nopunc, test_lcorrect_nopunc * 100 / test_total_nopunc))
    pred_writer.close()
    gold_writer.close()
Example #3
def run_biaffine(model_path, model_name, test_path, punct_set, use_gpu, logger,
                 args):
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, \
    type_alphabet = conllx_data.create_alphabets(
        alphabet_path,
        None,
        data_paths=[None, None],
        max_vocabulary_size=50000,
        embedd_dict=None
    )

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    decoding = args.decode

    logger.info('use gpu: %s, decoding: %s' % (use_gpu, decoding))

    device = torch.device('cuda') if use_gpu else torch.device('cpu')

    data_test = aida_data.read_data_to_tensor(test_path,
                                              word_alphabet,
                                              char_alphabet,
                                              pos_alphabet,
                                              type_alphabet,
                                              symbolic_root=True,
                                              device=device)

    pred_writer = AIDAWriter(word_alphabet, char_alphabet, pos_alphabet,
                             type_alphabet)

    logger.info('model: %s' % model_name)

    def load_model_arguments_from_json():
        arguments = json.load(open(arg_path, 'r'))
        return arguments['args'], arguments['kwargs']

    arg_path = model_name + '.arg.json'
    model_args, model_kwargs = load_model_arguments_from_json()
    network = BiRecurrentConvBiAffine(*model_args, **model_kwargs)
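    # Map the checkpoint onto the GPU when CUDA is available, otherwise load it onto the CPU.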
    if torch.cuda.is_available():
        map_location = lambda storage, loc: storage.cuda()
    else:
        map_location = 'cpu'
    network.load_state_dict(torch.load(model_name, map_location=map_location))

    if use_gpu:
        network.cuda()
    else:
        network.cpu()

    network.eval()

    if decoding == 'greedy':
        decode = network.decode
    elif decoding == 'mst':
        decode = network.decode_mst
    else:
        raise ValueError('Unknown decoding algorithm: %s' % decoding)

    pred_writer.start(args.output_path)
    sent = 0
    start_time = time.time()

    with torch.no_grad():
        for batch in aida_data.iterate_batch_tensor(data_test, 1):
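            # Decode one sentence at a time and write the predictions; gold heads/types are not evaluated here.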
            sys.stdout.write('Processing sentence: %d\n' % sent)
            sys.stdout.flush()
            sent += 1

            word, char, pos, _, _, masks, lengths, segment_ids_words = batch
            heads_pred, types_pred = decode(
                word,
                char,
                pos,
                mask=masks,
                length=lengths,
                leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            segment_ids, segment_words = zip(*segment_ids_words)
            pred_writer.write(segment_ids,
                              segment_words,
                              word,
                              pos,
                              heads_pred,
                              types_pred,
                              lengths,
                              symbolic_root=True)

        pred_writer.close()

    print('\ntime: %.2fs' % (time.time() - start_time))
Example #4
def biaffine(model_path, model_name, pre_model_path, pre_model_name, use_gpu, logger, args):
    alphabet_path = os.path.join(pre_model_path, 'alphabets/')
    logger.info("Alphabet Path: %s" % alphabet_path)
    pre_model_name = os.path.join(pre_model_path, pre_model_name)
    model_name = os.path.join(model_path, model_name)

    # Load pre-created alphabets
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet, max_sent_length = conllx_data.create_alphabets(
        alphabet_path, None, data_paths=[None, None], max_vocabulary_size=50000, embedd_dict=None)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info('use gpu: %s' % (use_gpu))

    if args.test_lang:
        extra_embed = args.embed_dir + ("wiki.multi.%s.vec" % args.test_lang)
        extra_word_dict, _ = load_embedding_dict('word2vec', extra_embed)
        test_path = args.data_dir + args.test_lang + '_test.conllu'
        extra_embeds_arr = augment_with_extra_embedding(word_alphabet, extra_word_dict, test_path, logger)
    else:
        extra_embeds_arr = []
        for language in args.langs:
            extra_embed = args.embed_dir + ("wiki.multi.%s.vec" % language)
            extra_word_dict, _ = load_embedding_dict('word2vec', extra_embed)

            test_path = args.data_dir + language + '_train.conllu'
            embeds_arr1 = augment_with_extra_embedding(word_alphabet, extra_word_dict, test_path, logger)
            test_path = args.data_dir + language + '_dev.conllu'
            embeds_arr2 = augment_with_extra_embedding(word_alphabet, extra_word_dict, test_path, logger)
            test_path = args.data_dir + language + '_test.conllu'
            embeds_arr3 = augment_with_extra_embedding(word_alphabet, extra_word_dict, test_path, logger)
            extra_embeds_arr.extend(embeds_arr1 + embeds_arr2 + embeds_arr3)

    # ------------------------------------------------------------------------- #
    # --------------------- Loading model ------------------------------------- #

    def load_model_arguments_from_json():
        arguments = json.load(open(arg_path, 'r'))
        return arguments['args'], arguments['kwargs']

    arg_path = pre_model_name + '.arg.json'
    margs, kwargs = load_model_arguments_from_json()
    network = BiRecurrentConvBiAffine(use_gpu=use_gpu, *margs, **kwargs)
    network.load_state_dict(torch.load(pre_model_name))
    args.use_bert = kwargs.get('use_bert', False)

    #
    augment_network_embed(word_alphabet.size(), network, extra_embeds_arr)

    network.eval()
    logger.info('model: %s' % pre_model_name)

    # Freeze the network
    for p in network.parameters():
        p.requires_grad = False

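    # A small classifier on top of the frozen encoder predicts which language each word (or sentence) comes from.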
    nclass = args.nclass
    classifier = nn.Sequential(
        nn.Linear(network.encoder.output_dim, 512),
        nn.Linear(512, nclass)
    )

    if use_gpu:
        network.cuda()
        classifier.cuda()
    else:
        network.cpu()
        classifier.cpu()

    batch_size = args.batch_size

    # ===== the reading
    def _read_one(path, is_train=False, max_size=None):
        lang_id = guess_language_id(path)
        logger.info("Reading: guess that the language of file %s is %s." % (path, lang_id))
        one_data = conllx_data.read_data_to_variable(path, word_alphabet, char_alphabet, pos_alphabet,
                                                     type_alphabet, use_gpu=use_gpu, volatile=(not is_train),
                                                     use_bert=args.use_bert, symbolic_root=True, lang_id=lang_id,
                                                     max_size=max_size)
        return one_data

    def compute_accuracy(data, lang_idx):
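        # Accuracy of the classifier at predicting lang_idx, at the word level or on mean-pooled sentence representations depending on args.train_level.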
        total_corr, total = 0, 0
        classifier.eval()
        with torch.no_grad():
            for batch in conllx_data.iterate_batch_variable(data, batch_size):
                word, char, pos, _, _, masks, lengths, bert_inputs = batch
                if use_gpu:
                    word = word.cuda()
                    char = char.cuda()
                    pos = pos.cuda()
                    masks = masks.cuda()
                    lengths = lengths.cuda()
                    if bert_inputs[0] is not None:
                        bert_inputs[0] = bert_inputs[0].cuda()
                        bert_inputs[1] = bert_inputs[1].cuda()
                        bert_inputs[2] = bert_inputs[2].cuda()

                output = network.forward(word, char, pos, input_bert=bert_inputs,
                                         mask=masks, length=lengths, hx=None)
                output = output['output'].detach()

                if args.train_level == 'word':
                    output = classifier(output)
                    output = output.contiguous().view(-1, output.size(2))
                else:
                    output = torch.mean(output, dim=1)
                    output = classifier(output)

                preds = output.max(1)[1].cpu()
                labels = torch.LongTensor([lang_idx])
                labels = labels.expand(*preds.size())
                n_correct = preds.eq(labels).sum().item()
                total_corr += n_correct
                total += output.size(0)

            return {'total_corr': total_corr, 'total': total}

    if args.test_lang:
        classifier.load_state_dict(torch.load(model_name))
        path = args.data_dir + args.test_lang + '_train.conllu'
        test_data = _read_one(path)

        # TODO: fixed indexing is not GOOD
        lang_idx = 0 if args.test_lang == args.src_lang else 1
        result = compute_accuracy(test_data, lang_idx)
        accuracy = (result['total_corr'] * 100.0) / result['total']
        logger.info('[Classifier performance] Language: %s || accuracy: %.2f%%' % (args.test_lang, accuracy))

    else:
        # if output directory doesn't exist, create it
        if not os.path.exists(args.model_path):
            os.makedirs(args.model_path)

        # --------------------- Loading data -------------------------------------- #
        train_data = dict()
        dev_data = dict()
        test_data = dict()
        num_data = dict()
        lang_ids = dict()
        reverse_lang_ids = dict()

        # loading language data
        for language in args.langs:
            lang_ids[language] = len(lang_ids)
            reverse_lang_ids[lang_ids[language]] = language

            train_path = args.data_dir + language + '_train.conllu'
            # Utilize at most 10000 examples
            tmp_data = _read_one(train_path, max_size=10000)
            num_data[language] = sum(tmp_data[1])
            train_data[language] = tmp_data

            dev_path = args.data_dir + language + '_dev.conllu'
            tmp_data = _read_one(dev_path)
            dev_data[language] = tmp_data

            test_path = args.data_dir + language + '_test.conllu'
            tmp_data = _read_one(test_path)
            test_data[language] = tmp_data

        # ------------------------------------------------------------------------- #

        optim = torch.optim.Adam(classifier.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        def compute_loss(lang_name, lang_idx):
            word, char, pos, _, _, masks, lengths, bert_inputs = conllx_data.get_batch_variable(train_data[lang_name],
                                                                                                batch_size,
                                                                                                unk_replace=0.5)

            if use_gpu:
                word = word.cuda()
                char = char.cuda()
                pos = pos.cuda()
                masks = masks.cuda()
                lengths = lengths.cuda()
                if bert_inputs[0] is not None:
                    bert_inputs[0] = bert_inputs[0].cuda()
                    bert_inputs[1] = bert_inputs[1].cuda()
                    bert_inputs[2] = bert_inputs[2].cuda()

            output = network.forward(word, char, pos, input_bert=bert_inputs,
                                     mask=masks, length=lengths, hx=None)
            output = output['output'].detach()

            if args.train_level == 'word':
                output = classifier(output)
                output = output.contiguous().view(-1, output.size(2))
            else:
                output = torch.mean(output, dim=1)
                output = classifier(output)

            labels = torch.empty(output.size(0)).fill_(lang_idx).type_as(output).long()
            loss = criterion(output, labels)
            return loss

        # ---------------------- Form the mini-batches -------------------------- #
        num_batches = 0
        batch_lang_labels = []
        for lang in args.langs:
            nbatches = num_data[lang] // batch_size + 1
            batch_lang_labels.extend([lang] * nbatches)
            num_batches += nbatches

        assert len(batch_lang_labels) == num_batches
        # ------------------------------------------------------------------------- #

        best_dev_accuracy = 0
        patience = 0
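        # Early stopping: keep the classifier with the best average dev accuracy and stop after 5 epochs without improvement.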
        for epoch in range(1, args.num_epochs + 1):
            # shuffling the data
            lang_in_batch = copy.copy(batch_lang_labels)
            random.shuffle(lang_in_batch)

            classifier.train()
            for batch in range(1, num_batches + 1):
                lang_name = lang_in_batch[batch - 1]
                lang_id = lang_ids.get(lang_name)

                loss = compute_loss(lang_name, lang_id)
                optim.zero_grad()  # reset gradients accumulated from the previous batch
                loss.backward()
                optim.step()

            # Validation
            avg_acc = dict()
            for dev_lang in dev_data.keys():
                lang_idx = lang_ids.get(dev_lang)
                result = compute_accuracy(dev_data[dev_lang], lang_idx)
                accuracy = (result['total_corr'] * 100.0) / result['total']
                avg_acc[dev_lang] = accuracy

            acc = ', '.join('%s: %.2f' % (key, val) for (key, val) in avg_acc.items())
            logger.info('Epoch: %d, Performance[%s]' % (epoch, acc))

            avg_acc = sum(avg_acc.values()) / len(avg_acc)
            if best_dev_accuracy < avg_acc:
                best_dev_accuracy = avg_acc
                patience = 0
                state_dict = classifier.state_dict()
                torch.save(state_dict, model_name)
            else:
                patience += 1

            if patience >= 5:
                break

        # Testing
        logger.info('Testing model %s' % pre_model_name)
        total_corr, total = 0, 0
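        # Report per-language and micro-averaged classification accuracy on the held-out test sets.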
        for test_lang in UD_languages:
            if test_lang in test_data:
                lang_idx = lang_ids.get(test_lang)
                result = compute_accuracy(test_data[test_lang], lang_idx)
                accuracy = (result['total_corr'] * 100.0) / result['total']
                print('[LANG]: %s, [ACC]: %.2f' % (test_lang.upper(), accuracy))
                total_corr += result['total_corr']
                total += result['total']
        print('[Avg. Performance]: %.2f' % ((total_corr * 100.0) / total))