Example #1
    # Reconstructed head of the truncated snippet; the batch_size and
    # shuffle values are assumptions.
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             drop_last=True)

    _, num_class_test = test_dataset.get_class_weight()
    print('\nNumber of testing samples: ' + str(len(test_dataset)))
    for i, c in enumerate(num_class_test):
        print("\tLabel {:d}:".format(i).ljust(15) + "{:d}".format(c).rjust(8))

    args.num_features = len(test_dataset.alphabet)
    model = CharCNN(args)
    print("=> loading weights from '{}'".format(args.model_path))
    assert os.path.isfile(
        args.model_path), "=> no checkpoint found at '{}'".format(
            args.model_path)
    checkpoint = torch.load(args.model_path)
    model.load_state_dict(checkpoint['state_dict'])

    # using GPU
    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    model.eval()
    corrects, avg_loss, accumulated_loss, size = 0, 0, 0, 0
    predicates_all, target_all = [], []
    print('\nTesting...')
    for i_batch, (inputs, target) in enumerate(test_loader):
        if args.target_sub_scaler != 0:
            target.sub_(args.target_sub_scaler)
        size += len(target)
        if args.cuda:
            inputs, target = inputs.cuda(), target.cuda()  # move batch to GPU
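        # --- Hypothetical continuation; the source snippet is truncated here. ---
        # A minimal sketch of how such an evaluation loop is typically finished,
        # assuming the model returns per-class logits. The bookkeeping matches
        # the variables initialized above.
        logits = model(inputs)
        loss = torch.nn.functional.cross_entropy(logits, target, reduction='sum')
        accumulated_loss += loss.item()
        predicates = torch.max(logits, dim=1)[1]
        corrects += (predicates == target).sum().item()
        predicates_all += predicates.cpu().tolist()
        target_all += target.cpu().tolist()

    avg_loss = accumulated_loss / size
    print('Evaluation - loss: {:.6f}  acc: {:.2f}% ({}/{})'.format(
        avg_loss, 100.0 * corrects / size, corrects, size))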
Example #2
# Build the CRF layer; the two extra labels are its start/end transition states.
crf = CRF(label_size=len(label_vocab) + 2)
linear = Linear(in_features=lstm.output_size, out_features=len(label_vocab))
lstm_crf = LstmCrf(token_vocab,
                   label_vocab,
                   char_vocab,
                   word_embedding=word_embed,
                   char_embedding=char_embed,
                   crf=crf,
                   lstm=lstm,
                   univ_fc_layer=linear,
                   embed_dropout_prob=train_args['feat_dropout'],
                   lstm_dropout_prob=train_args['lstm_dropout'],
                   char_highway=char_hw if train_args['use_highway'] else None)

# Restore each component's weights from the saved checkpoint.
word_embed.load_state_dict(state['model']['word_embed'])
char_embed.load_state_dict(state['model']['char_embed'])
char_hw.load_state_dict(state['model']['char_hw'])
lstm.load_state_dict(state['model']['lstm'])
crf.load_state_dict(state['model']['crf'])
linear.load_state_dict(state['model']['linear'])
lstm_crf.load_state_dict(state['model']['lstm_crf'])

if use_gpu:
    lstm_crf.cuda()

# Load dataset
logger.info('Loading data')
parser = ConllParser()
test_set = SeqLabelDataset(data_file, parser=parser)
test_set.numberize(token_vocab, label_vocab, char_vocab)
idx_token = {v: k for k, v in token_vocab.items()}
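# A matching inverse map for labels is the natural companion when decoding
# predictions; `idx_label` is a hypothetical name, and LstmCrf's prediction
# API is not shown in this snippet.
idx_label = {v: k for k, v in label_vocab.items()}
lstm_crf.eval()  # switch off dropout before decoding the test set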
Example #3

from model import CharCNN
import torch

model = CharCNN(70, 0.5)  # 70-character alphabet; 0.5 is presumably dropout
model.load_state_dict(torch.load('save_model/best.pt'))
model.eval()  # disable dropout for inference

sent = "U.S. Brokers Cease-fire in Western Afghanistan KABUL (Reuters) - The United States has brokered a  cease-fire between a renegade Afghan militia leader and the  embattled governor of the western province of Herat,  Washington's envoy to Kabul said Tuesday."
sent_tensor = torch.zeros(1014).long()
alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}\n"
for i, char in enumerate(sent.lower()):  # alphabet is lowercase-only
    if i == 1014:
        break
    alphabet_index = alphabet.find(char)
    if alphabet_index != -1:
        sent_tensor[i] = alphabet_index

sent_tensor = sent_tensor.unsqueeze(0)  # add batch dimension
out_feature = model(sent_tensor)
out_feature = out_feature.squeeze(0)
print('out_feature:', out_feature)
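
# The ad-hoc encoding above factors into a small helper; a sketch under the
# same assumptions (lowercase 70-char alphabet, 1014-char cap, index 0 shared
# by padding and 'a').
def encode(sentence, alphabet, max_len=1014):
    tensor = torch.zeros(max_len).long()
    for i, char in enumerate(sentence.lower()[:max_len]):
        idx = alphabet.find(char)
        if idx != -1:
            tensor[i] = idx
    return tensor.unsqueeze(0)  # add batch dimension

with torch.no_grad():  # no gradients needed at inference time
    out_feature = model(encode(sent, alphabet)).squeeze(0)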
Example #4
# Rebuild each component with the saved training configuration.
lstm = LSTM(
    Config({
        'input_size': word_embed.output_size + char_cnn.output_size,
        'hidden_size': train_args['lstm_hidden_size'],
        'forget_bias': 1.0,
        'batch_first': True,
        'bidirectional': True
    }))
crf = CRF(Config({'label_vocab': label_vocab}))
output_linear = Linear(
    Config({
        'in_features': lstm.output_size,
        'out_features': len(label_vocab)
    }))
# Restore pretrained weights into each component.
word_embed.load_state_dict(state['model']['word_embed'])
char_cnn.load_state_dict(state['model']['char_cnn'])
char_highway.load_state_dict(state['model']['char_highway'])
lstm.load_state_dict(state['model']['lstm'])
crf.load_state_dict(state['model']['crf'])
output_linear.load_state_dict(state['model']['output_linear'])
lstm_crf = LstmCrf(token_vocab=token_vocab,
                   label_vocab=label_vocab,
                   char_vocab=char_vocab,
                   word_embedding=word_embed,
                   char_embedding=char_cnn,
                   crf=crf,
                   lstm=lstm,
                   univ_fc_layer=output_linear,
                   embed_dropout_prob=train_args['embed_dropout'],
                   lstm_dropout_prob=train_args['lstm_dropout'],
                   linear_dropout_prob=train_args['linear_dropout'],