Example #1
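# Imports assumed by all four examples below (NeuroNLP2-style project layout;
# adjust the module paths to your checkout). calc_loss and display, used inside
# analyze() in Example #1, are helpers defined elsewhere in the same script.
import os
import sys
import json
import time
import uuid

import numpy as np
import torch

from neuronlp2 import utils
from neuronlp2.io import conllx_stacked_data, CoNLLXWriter
from neuronlp2.models import StackPtrNet
from neuronlp2.tasks import parser

uid = uuid.uuid4().hex[:6]  # unique suffix for the tmp/analyze_* output files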
def stackptr(model_path, model_name, test_path, punct_set, use_gpu, logger,
             args):
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = \
        conllx_stacked_data.create_alphabets(alphabet_path,
                                             None,
                                             data_paths=[None, None],
                                             max_vocabulary_size=50000,
                                             embedd_dict=None)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    beam = args.beam
    ordered = args.ordered
    display_inst = args.display

    def load_model_arguments_from_json():
        arguments = json.load(open(arg_path, 'r'))
        return arguments['args'], arguments['kwargs']

    arg_path = model_name + '.arg.json'
    args, kwargs = load_model_arguments_from_json()

    prior_order = kwargs['prior_order']
    logger.info('use gpu: %s, beam: %d, order: %s (%s)' %
                (use_gpu, beam, prior_order, ordered))

    data_test = conllx_stacked_data.read_stacked_data_to_tensor(
        test_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        use_gpu=use_gpu,
        volatile=True,
        prior_order=prior_order)

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)

    logger.info('model: %s' % model_name)
    network = StackPtrNet(*args, **kwargs)
    network.load_state_dict(torch.load(model_name))

    if use_gpu:
        network.cuda()
    else:
        network.cpu()

    network.eval()

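    # Accumulators for the end-of-run report: UAS/LAS counts with and without
    # punctuation, unlabeled/labeled complete matches (ucm/lcm), root accuracy,
    # and per-decoder-step accuracy split into leaf and non-leaf steps.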
    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomplete_match = 0.0
    test_lcomplete_match = 0.0
    test_total = 0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomplete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_total_nopunc = 0
    test_total_inst = 0

    test_root_correct = 0.0
    test_total_root = 0

    test_ucorrect_stack_leaf = 0.0
    test_ucorrect_stack_non_leaf = 0.0

    test_lcorrect_stack_leaf = 0.0
    test_lcorrect_stack_non_leaf = 0.0

    test_leaf = 0
    test_non_leaf = 0

    pred_writer.start('tmp/analyze_pred_%s' % str(uid))
    gold_writer.start('tmp/analyze_gold_%s' % str(uid))
    sent = 0
    start_time = time.time()
    for batch in conllx_stacked_data.iterate_batch_stacked_variable(
            data_test, 1):
        sys.stdout.write('%d, ' % sent)
        sys.stdout.flush()
        sent += 1

        input_encoder, input_decoder = batch
        word, char, pos, heads, types, masks, lengths = input_encoder
        stacked_heads, children, siblings, stacked_types, skip_connect, mask_d, lengths_d = input_decoder
        heads_pred, types_pred, children_pred, stacked_types_pred = network.decode(
            word,
            char,
            pos,
            mask=masks,
            length=lengths,
            beam=beam,
            ordered=ordered,
            leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

        stacked_heads = stacked_heads.data
        children = children.data
        stacked_types = stacked_types.data
        children_pred = torch.from_numpy(children_pred).long()
        stacked_types_pred = torch.from_numpy(stacked_types_pred).long()
        if use_gpu:
            children_pred = children_pred.cuda()
            stacked_types_pred = stacked_types_pred.cuda()
        mask_d = mask_d.data
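        # A decoder step is a "leaf" step when its (gold) child is the word on
        # top of the stack itself, i.e. the stack-pointer points at the head to
        # signal it has no remaining children and is popped; every other step
        # attaches a new child ("non-leaf"). Accuracy is tracked for both.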
        mask_leaf = torch.eq(children, stacked_heads).float()
        mask_non_leaf = (1.0 - mask_leaf)
        mask_leaf = mask_leaf * mask_d
        mask_non_leaf = mask_non_leaf * mask_d
        num_leaf = mask_leaf.sum()
        num_non_leaf = mask_non_leaf.sum()

        ucorr_stack = torch.eq(children_pred, children).float()
        lcorr_stack = ucorr_stack * torch.eq(stacked_types_pred,
                                             stacked_types).float()
        ucorr_stack_leaf = (ucorr_stack * mask_leaf).sum()
        ucorr_stack_non_leaf = (ucorr_stack * mask_non_leaf).sum()

        lcorr_stack_leaf = (lcorr_stack * mask_leaf).sum()
        lcorr_stack_non_leaf = (lcorr_stack * mask_non_leaf).sum()

        test_ucorrect_stack_leaf += ucorr_stack_leaf
        test_ucorrect_stack_non_leaf += ucorr_stack_non_leaf
        test_lcorrect_stack_leaf += lcorr_stack_leaf
        test_lcorrect_stack_non_leaf += lcorr_stack_non_leaf

        test_leaf += num_leaf
        test_non_leaf += num_non_leaf

        # ------------------------------------------------------------------------------------------------

        word = word.data.cpu().numpy()
        pos = pos.data.cpu().numpy()
        lengths = lengths.cpu().numpy()
        heads = heads.data.cpu().numpy()
        types = types.data.cpu().numpy()

        pred_writer.write(word,
                          pos,
                          heads_pred,
                          types_pred,
                          lengths,
                          symbolic_root=True)
        gold_writer.write(word, pos, heads, types, lengths, symbolic_root=True)

        stats, stats_nopunc, stats_root, num_inst = parser.eval(
            word,
            pos,
            heads_pred,
            types_pred,
            heads,
            types,
            word_alphabet,
            pos_alphabet,
            lengths,
            punct_set=punct_set,
            symbolic_root=True)
        ucorr, lcorr, total, ucm, lcm = stats
        ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
        corr_root, total_root = stats_root

        test_ucorrect += ucorr
        test_lcorrect += lcorr
        test_total += total
        test_ucomplete_match += ucm
        test_lcomplete_match += lcm

        test_ucorrect_nopunc += ucorr_nopunc
        test_lcorrect_nopunc += lcorr_nopunc
        test_total_nopunc += total_nopunc
        test_ucomplete_match_nopunc += ucm_nopunc
        test_lcomplete_match_nopunc += lcm_nopunc

        test_root_correct += corr_root
        test_total_root += total_root

        test_total_inst += num_inst

    pred_writer.close()
    gold_writer.close()

    print('\ntime: %.2fs' % (time.time() - start_time))

    print(
        'test W. Punct:  ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
        % (test_ucorrect, test_lcorrect, test_total,
           test_ucorrect * 100 / test_total, test_lcorrect * 100 / test_total,
           test_ucomplete_match * 100 / test_total_inst,
           test_lcomplete_match * 100 / test_total_inst))
    print(
        'test Wo Punct:  ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
        % (test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
           test_ucorrect_nopunc * 100 / test_total_nopunc,
           test_lcorrect_nopunc * 100 / test_total_nopunc,
           test_ucomplete_match_nopunc * 100 / test_total_inst,
           test_lcomplete_match_nopunc * 100 / test_total_inst))
    print('test Root: corr: %d, total: %d, acc: %.2f%%' %
          (test_root_correct, test_total_root,
           test_root_correct * 100 / test_total_root))
    print(
        '============================================================================================================================'
    )

    print(
        'Stack leaf:     ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%'
        % (test_ucorrect_stack_leaf, test_lcorrect_stack_leaf, test_leaf,
           test_ucorrect_stack_leaf * 100 / test_leaf,
           test_lcorrect_stack_leaf * 100 / test_leaf))
    print(
        'Stack non_leaf: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%'
        % (test_ucorrect_stack_non_leaf, test_lcorrect_stack_non_leaf,
           test_non_leaf, test_ucorrect_stack_non_leaf * 100 / test_non_leaf,
           test_lcorrect_stack_non_leaf * 100 / test_non_leaf))
    print(
        '============================================================================================================================'
    )

    def analyze():
        np.set_printoptions(linewidth=100000)
        pred_path = 'tmp/analyze_pred_%s' % str(uid)
        data_gold = conllx_stacked_data.read_stacked_data_to_tensor(
            test_path,
            word_alphabet,
            char_alphabet,
            pos_alphabet,
            type_alphabet,
            use_gpu=use_gpu,
            volatile=True,
            prior_order=prior_order)
        data_pred = conllx_stacked_data.read_stacked_data_to_tensor(
            pred_path,
            word_alphabet,
            char_alphabet,
            pos_alphabet,
            type_alphabet,
            use_gpu=use_gpu,
            volatile=True,
            prior_order=prior_order)

        gold_iter = conllx_stacked_data.iterate_batch_stacked_variable(
            data_gold, 1)
        test_iter = conllx_stacked_data.iterate_batch_stacked_variable(
            data_pred, 1)
        model_err = 0
        search_err = 0
        type_err = 0
        for gold, pred in zip(gold_iter, test_iter):
            gold_encoder, gold_decoder = gold
            word, char, pos, gold_heads, gold_types, masks, lengths = gold_encoder
            gold_stacked_heads, gold_children, gold_siblings, gold_stacked_types, gold_skip_connect, gold_mask_d, gold_lengths_d = gold_decoder

            pred_encoder, pred_decoder = pred
            _, _, _, pred_heads, pred_types, _, _ = pred_encoder
            pred_stacked_heads, pred_children, pred_siblings, pred_stacked_types, pred_skip_connect, pred_mask_d, pred_lengths_d = pred_decoder

            assert gold_heads.size() == pred_heads.size(), 'sentence mismatch.'

            ucorr_stack = torch.eq(pred_children, gold_children).float()
            lcorr_stack = ucorr_stack * torch.eq(pred_stacked_types,
                                                 gold_stacked_types).float()
            ucorr_stack = (ucorr_stack * gold_mask_d).data.sum()
            lcorr_stack = (lcorr_stack * gold_mask_d).data.sum()
            num_stack = gold_mask_d.data.sum()

            if lcorr_stack < num_stack:
                loss_pred, loss_pred_arc, loss_pred_type = calc_loss(
                    network, word, char, pos, pred_heads, pred_stacked_heads,
                    pred_children, pred_siblings, pred_stacked_types,
                    pred_skip_connect, masks, lengths, pred_mask_d,
                    pred_lengths_d)

                loss_gold, loss_gold_arc, loss_gold_type = calc_loss(
                    network, word, char, pos, gold_heads, gold_stacked_heads,
                    gold_children, gold_siblings, gold_stacked_types,
                    gold_skip_connect, masks, lengths, gold_mask_d,
                    gold_lengths_d)

                if display_inst:
                    print('%d, %d, %d' % (ucorr_stack, lcorr_stack, num_stack))
                    print(
                        'pred(arc, type): %.4f (%.4f, %.4f), gold(arc, type): %.4f (%.4f, %.4f)'
                        % (loss_pred, loss_pred_arc, loss_pred_type, loss_gold,
                           loss_gold_arc, loss_gold_type))
                    word = word[0].data.cpu().numpy()
                    pos = pos[0].data.cpu().numpy()
                    head_gold = gold_heads[0].data.cpu().numpy()
                    type_gold = gold_types[0].data.cpu().numpy()
                    head_pred = pred_heads[0].data.cpu().numpy()
                    type_pred = pred_types[0].data.cpu().numpy()
                    display(word, pos, head_gold, type_gold, head_pred,
                            type_pred, lengths[0], word_alphabet, pos_alphabet,
                            type_alphabet)

                    length_dec = gold_lengths_d[0]
                    gold_display = np.empty([3, length_dec])
                    gold_display[0] = gold_stacked_types.data[0].cpu().numpy()[:length_dec]
                    gold_display[1] = gold_children.data[0].cpu().numpy()[:length_dec]
                    gold_display[2] = gold_stacked_heads.data[0].cpu().numpy()[:length_dec]
                    print(gold_display)
                    print(
                        '--------------------------------------------------------'
                    )
                    pred_display = np.empty([3, length_dec])  # same layout as gold_display
                    pred_display[0] = pred_stacked_types.data[0].cpu().numpy()[:length_dec]
                    pred_display[1] = pred_children.data[0].cpu().numpy()[:length_dec]
                    pred_display[2] = pred_stacked_heads.data[0].cpu().numpy()[:length_dec]
                    print(pred_display)
                    print(
                        '========================================================'
                    )
                    input()  # pause so each mistaken instance can be inspected

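                # Classify the mistake: every arc right but a label wrong is a
                # type error; if the model assigns its wrong prediction a lower
                # loss than the gold tree, the model itself is at fault (model
                # error); otherwise the gold tree scores better and beam search
                # simply failed to find it (search error).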
                if ucorr_stack == num_stack:
                    type_err += 1
                elif loss_pred < loss_gold:
                    model_err += 1
                else:
                    search_err += 1
        print('type   errors: %d' % type_err)
        print('model  errors: %d' % model_err)
        print('search errors: %d' % search_err)

    analyze()
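
# A minimal driver sketch for Example #1. Hedged: the flag names are
# illustrative, not taken from the original script, and the tmp/ directory the
# writers target is assumed to exist.
if __name__ == '__main__':
    import argparse
    import logging

    cli = argparse.ArgumentParser(description='analyze a trained StackPtrNet')
    cli.add_argument('--model_path', required=True)
    cli.add_argument('--model_name', required=True)
    cli.add_argument('--test', dest='test_path', required=True)
    cli.add_argument('--punctuation', nargs='+', default=[])
    cli.add_argument('--gpu', action='store_true')
    cli.add_argument('--beam', type=int, default=1)
    cli.add_argument('--ordered', action='store_true')
    cli.add_argument('--display', action='store_true')
    cli_args = cli.parse_args()

    logging.basicConfig(level=logging.INFO)
    stackptr(cli_args.model_path, cli_args.model_name, cli_args.test_path,
             set(cli_args.punctuation), cli_args.gpu,
             logging.getLogger('analyze'), cli_args)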
Example #2
def stackptr(model_path, model_name, test_path, punct_set, use_gpu, logger,
             args):
    pos_embedding = args.pos_embedding
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_stacked_data.create_alphabets(
        alphabet_path,
        None,
        pos_embedding,
        data_paths=[None, None],
        max_vocabulary_size=50000,
        embedd_dict=None)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    beam = args.beam
    ordered = args.ordered
    display_inst = args.display

    def load_model_arguments_from_json():
        arguments = json.load(open(arg_path, 'r'))
        return arguments['args'], arguments['kwargs']

    arg_path = model_name + '.arg.json'
    args, kwargs = load_model_arguments_from_json()

    prior_order = kwargs['prior_order']
    logger.info('use gpu: %s, beam: %d, order: %s (%s)' %
                (use_gpu, beam, prior_order, ordered))

    data_test = conllx_stacked_data.read_stacked_data_to_variable(
        test_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        pos_embedding,
        use_gpu=use_gpu,
        volatile=True,
        prior_order=prior_order,
        is_test=True)

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet, pos_embedding)

    logger.info('model: %s' % model_name)
    # add the pretrained word embedding table to kwargs
    word_path = os.path.join(model_path, 'embedding.txt')
    word_dict, word_dim = utils.load_embedding_dict('NNLM', word_path)

    def get_embedding_table():
        table = np.empty([len(word_dict), word_dim], dtype=np.float32)
        for idx, (word, embedding) in enumerate(word_dict.items()):
            try:
                table[idx, :] = embedding
            except ValueError:  # embedding has the wrong dimensionality
                print(word)
        return torch.from_numpy(table)

    word_table = get_embedding_table()
    kwargs['embedd_word'] = word_table
    args[1] = len(word_dict)  # num_words: resize the vocabulary to the loaded table
    network = StackPtrNet(*args, **kwargs)
    # load every pretrained weight except the word embedding (replaced above)
    model_dict = network.state_dict()
    pretrained_dict = torch.load(model_name)
    model_dict.update({
        k: v
        for k, v in pretrained_dict.items() if k != 'word_embedd.weight'
    })

    network.load_state_dict(model_dict)

    if use_gpu:
        network.cuda()
    else:
        network.cpu()

    network.eval()

    if not ordered:
        pred_writer.start(model_path + '/tmp/inference.txt')
    else:
        pred_writer.start(model_path + '/tmp/inference_ordered_temp.txt')
    sent = 0
    start_time = time.time()
    for batch in conllx_stacked_data.iterate_batch_stacked_variable(
            data_test, 1, pos_embedding, type='dev'):
        sys.stdout.write('%d, ' % sent)
        sys.stdout.flush()
        sent += 1

        input_encoder, input_decoder = batch
        word, char, pos, heads, types, masks, lengths = input_encoder
        stacked_heads, children, siblings, stacked_types, skip_connect, mask_d, lengths_d = input_decoder
        heads_pred, types_pred, children_pred, stacked_types_pred = network.decode(
            word,
            char,
            pos,
            mask=masks,
            length=lengths,
            beam=beam,
            ordered=ordered,
            leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

        stacked_heads = stacked_heads.data
        children = children.data
        stacked_types = stacked_types.data
        children_pred = torch.from_numpy(children_pred).long()
        stacked_types_pred = torch.from_numpy(stacked_types_pred).long()
        if use_gpu:
            children_pred = children_pred.cuda()
            stacked_types_pred = stacked_types_pred.cuda()

        word = word.data.cpu().numpy()
        pos = pos.data.cpu().numpy()
        lengths = lengths.cpu().numpy()
        heads = heads.data.cpu().numpy()
        types = types.data.cpu().numpy()

        pred_writer.write(word,
                          pos,
                          heads_pred,
                          types_pred,
                          lengths,
                          symbolic_root=True)
    pred_writer.close()
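
# All four variants read the network configuration back from
# model_name + '.arg.json'. A minimal sketch of the matching writer the
# training side would need; the layout mirrors load_model_arguments_from_json
# above, and only JSON-serializable entries (e.g. 'prior_order', not the
# 'embedd_word' tensor) belong in the kwargs dict:
import json

def save_model_arguments_to_json(arg_path, model_args, model_kwargs):
    with open(arg_path, 'w') as f:
        json.dump({'args': model_args, 'kwargs': model_kwargs}, f)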
Example #3
def stackptr(model_path, model_name, test_path, punct_set, use_gpu, logger,
             args):
    pos_embedding = args.pos_embedding
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = \
        conllx_stacked_data.create_alphabets(alphabet_path,
                                             None,
                                             pos_embedding,
                                             data_paths=[None, None],
                                             max_vocabulary_size=50000,
                                             embedd_dict=None)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    beam = args.beam
    ordered = args.ordered
    display_inst = args.display

    def load_model_arguments_from_json():
        arguments = json.load(open(arg_path, 'r'))
        return arguments['args'], arguments['kwargs']

    arg_path = model_name + '.arg.json'
    args, kwargs = load_model_arguments_from_json()

    prior_order = kwargs['prior_order']
    logger.info('use gpu: %s, beam: %d, order: %s (%s)' %
                (use_gpu, beam, prior_order, ordered))

    data_test = conllx_stacked_data.read_stacked_data_to_variable(
        test_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        pos_embedding,
        use_gpu=use_gpu,
        volatile=True,
        prior_order=prior_order,
        is_test=True)

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet, pos_embedding)
    #gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet)

    logger.info('model: %s' % model_name)
    network = StackPtrNet(*args, **kwargs)
    network.load_state_dict(torch.load(model_name))

    if use_gpu:
        network.cuda()
    else:
        network.cpu()

    network.eval()

    # join the path pieces so a model_path without a trailing slash still works
    pred_writer.start(os.path.join(model_path, 'tmp/analyze_pred'))
    sent = 0
    start_time = time.time()
    for batch in conllx_stacked_data.iterate_batch_stacked_variable(
            data_test, 1):
        sys.stdout.write('%d, ' % sent)
        sys.stdout.flush()
        sent += 1

        input_encoder, input_decoder = batch
        word, char, pos, heads, types, masks, lengths = input_encoder
        stacked_heads, children, siblings, stacked_types, skip_connect, mask_d, lengths_d = input_decoder
        heads_pred, types_pred, children_pred, stacked_types_pred = network.decode(
            word,
            char,
            pos,
            mask=masks,
            length=lengths,
            beam=beam,
            ordered=ordered,
            leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

        stacked_heads = stacked_heads.data
        children = children.data
        stacked_types = stacked_types.data
        children_pred = torch.from_numpy(children_pred).long()
        stacked_types_pred = torch.from_numpy(stacked_types_pred).long()
        if use_gpu:
            children_pred = children_pred.cuda()
            stacked_types_pred = stacked_types_pred.cuda()
        mask_d = mask_d.data
        mask_leaf = torch.eq(children, stacked_heads).float()
        mask_non_leaf = (1.0 - mask_leaf)
        mask_leaf = mask_leaf * mask_d
        mask_non_leaf = mask_non_leaf * mask_d
        num_leaf = mask_leaf.sum()
        num_non_leaf = mask_non_leaf.sum()

        # ------------------------------------------------------------------------------------------------

        word = word.data.cpu().numpy()
        pos = pos.data.cpu().numpy()
        lengths = lengths.cpu().numpy()
        heads = heads.data.cpu().numpy()
        types = types.data.cpu().numpy()

        pred_writer.write(word,
                          pos,
                          heads_pred,
                          types_pred,
                          lengths,
                          symbolic_root=True)

    pred_writer.close()
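
# The writer emits one token per line in the 10-column CoNLL-X layout with
# blank lines between sentences. A hedged reader sketch for the predictions
# file written above (assumes tab-separated columns; the function name is
# illustrative, not part of the original scripts):
def read_conllx_predictions(path):
    sentences, heads, types = [], [], []
    with open(path, 'r') as f:
        for block in f.read().strip().split('\n\n'):
            cols = [line.split('\t') for line in block.split('\n')]
            sentences.append([c[1] for c in cols])   # FORM
            heads.append([int(c[6]) for c in cols])  # HEAD
            types.append([c[7] for c in cols])       # DEPREL
    return sentences, heads, types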
Example #4
def stackptr(model_path, model_name, test_path, punct_set, use_gpu, logger, args):
    pos_embedding = args.pos_embedding
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = \
        conllx_stacked_data.create_alphabets(alphabet_path,
                                             None,
                                             pos_embedding,
                                             data_paths=[None, None],
                                             max_vocabulary_size=50000,
                                             embedd_dict=None)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    beam = args.beam
    ordered = args.ordered
    use_bert = args.bert
    bert_path = args.bert_path
    bert_feature_dim = args.bert_feature_dim
    if use_bert:
        etri_test_path = args.etri_test
    else:
        etri_test_path = None

    def load_model_arguments_from_json():
        arguments = json.load(open(arg_path, 'r'))
        return arguments['args'], arguments['kwargs']

    arg_path = model_name + '.arg.json'
    args, kwargs = load_model_arguments_from_json()

    prior_order = kwargs['prior_order']
    logger.info('use gpu: %s, beam: %d, order: %s (%s)' % (use_gpu, beam, prior_order, ordered))

    data_test = conllx_stacked_data.read_stacked_data_to_variable(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, pos_embedding,
                                                                  use_gpu=use_gpu, volatile=True, prior_order=prior_order, is_test=False,
                                                                  bert=use_bert, etri_path=etri_test_path)

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet, pos_embedding)

    logger.info('model: %s' % model_name)
    word_path = os.path.join(model_path, 'embedding.txt')
    word_dict, word_dim = utils.load_embedding_dict('NNLM', word_path)

    # kept but unused in this variant; see the commented-out call below
    def get_embedding_table():
        table = np.empty([len(word_dict), word_dim], dtype=np.float32)
        for idx, (word, embedding) in enumerate(word_dict.items()):
            try:
                table[idx, :] = embedding
            except ValueError:  # embedding has the wrong dimensionality
                print(word)
        return torch.from_numpy(table)

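    # construct_word_embedding_table builds one row per alphabet entry instead:
    # known words copy their pretrained vector (falling back to the lowercased
    # form), and UNK/OOV rows are drawn from U(-scale, scale) with
    # scale = sqrt(3 / word_dim), so a random row has variance 1/word_dim.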
    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in list(word_alphabet.items()):
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    # word_table = get_embedding_table()
    word_table = construct_word_embedding_table()
    # kwargs['embedd_word'] = word_table
    # args[1] = len(word_dict) # word_dim

    network = StackPtrNet(*args, **kwargs, bert=use_bert, bert_path=bert_path, bert_feature_dim=bert_feature_dim)
    network.load_state_dict(torch.load(model_name))
    """
    model_dict = network.state_dict()
    pretrained_dict = torch.load(model_name)
    model_dict.update({k:v for k,v in list(pretrained_dict.items())
        if k != 'word_embedd.weight'})
    
    network.load_state_dict(model_dict)
    """

    if use_gpu:
        network.cuda()
    else:
        network.cpu()

    network.eval()

    if not ordered:
        pred_writer.start(model_path + '/inference.txt')
    else:
        pred_writer.start(model_path + '/RL_B[test].txt')
    sent = 1

    dev_ucorr_nopunc = 0.0
    dev_lcorr_nopunc = 0.0
    dev_total_nopunc = 0
    dev_ucomplete_nopunc = 0.0
    dev_lcomplete_nopunc = 0.0
    dev_total_inst = 0.0
    sys.stdout.write('Start!\n')
    start_time = time.time()
    for batch in conllx_stacked_data.iterate_batch_stacked_variable(data_test, 1, pos_embedding, type='dev', bert=use_bert):
        if sent % 100 == 0:
            print('Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
                dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc, dev_ucorr_nopunc * 100 / dev_total_nopunc,
                dev_lcorr_nopunc * 100 / dev_total_nopunc, dev_ucomplete_nopunc * 100 / dev_total_inst,
                dev_lcomplete_nopunc * 100 / dev_total_inst))
            sys.stdout.write('[%d/%d]\n' % (sent, int(data_test[2][0])))
        sys.stdout.flush()
        sent += 1

        input_encoder, input_decoder = batch
        word, char, pos, heads, types, masks_e, lengths, word_bert = input_encoder
        stacked_heads, children, sibling, stacked_types, skip_connect, previous, nexts, masks_d, lengths_d = input_decoder
        heads_pred, types_pred, _, _ = network.decode(
            word, char, pos, previous, nexts, stacked_heads,
            mask_e=masks_e, mask_d=masks_d, length=lengths, beam=beam,
            leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS,
            input_word_bert=word_bert)
        """
        stacked_heads = stacked_heads.data
        children = children.data
        stacked_types = stacked_types.data
        children_pred = torch.from_numpy(children_pred).long()
        stacked_types_pred = torch.from_numpy(stacked_types_pred).long()
        if use_gpu:
            children_pred = children_pred.cuda()
            stacked_types_pred = stacked_types_pred.cuda()
        """

        word = word.data.cpu().numpy()
        pos = pos.data.cpu().numpy()
        lengths = lengths.cpu().numpy()
        heads = heads.data.cpu().numpy()
        types = types.data.cpu().numpy()

        pred_writer.test_write(word, pos, heads_pred, types_pred, lengths, symbolic_root=True)
        stats, stats_nopunc, _, num_inst = parser.eval(
            word, pos, heads_pred, types_pred, heads, types,
            word_alphabet, pos_alphabet, lengths,
            punct_set=punct_set, symbolic_root=True)

        ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
        dev_ucorr_nopunc += ucorr_nopunc
        dev_lcorr_nopunc += lcorr_nopunc
        dev_total_nopunc += total_nopunc
        dev_ucomplete_nopunc += ucm_nopunc
        dev_lcomplete_nopunc += lcm_nopunc

        dev_total_inst += num_inst
    end_time = time.time()
    pred_writer.close()

    print('\nFINISHED!! time: %.2fs' % (end_time - start_time))
    print('Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
        dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc, dev_ucorr_nopunc * 100 / dev_total_nopunc,
        dev_lcorr_nopunc * 100 / dev_total_nopunc, dev_ucomplete_nopunc * 100 / dev_total_inst,
        dev_lcomplete_nopunc * 100 / dev_total_inst))
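
# Hedged helper sketch: the per-batch stats and stats_nopunc tuples returned by
# parser.eval above both unpack to the same five counts, so the percentage
# arithmetic repeated in the prints can be factored out (the function name is
# illustrative, not part of the original scripts):
def summarize(ucorr, lcorr, total, ucm, lcm, num_inst):
    return {'uas': ucorr * 100.0 / total,
            'las': lcorr * 100.0 / total,
            'ucm': ucm * 100.0 / num_inst,
            'lcm': lcm * 100.0 / num_inst}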