Example No. 1
def own_test(text_field, label_field):
    test_iter = load_test_dataset(text_field, label_field)

    args.embedding_dim = text_field.vocab.vectors.size()[-1]
    args.vectors = text_field.vocab.vectors
    args.vocabulary_size = len(text_field.vocab)
    args.class_num = 3
    args.cuda = args.device != -1 and torch.cuda.is_available()
    # print('Parameters:')
    # for attr, value in sorted(args.__dict__.items()):
    #     if attr in {'vectors'}:
    #         continue
    #     print('\t{}={}'.format(attr.upper(), value))

    text_cnn = model.TextCNN(args)
    if args.snapshot:
        print('\nLoading model from {}...\n'.format(args.snapshot))
        text_cnn.load_state_dict(torch.load(args.snapshot))
    else:
        text_cnn.load_state_dict(torch.load('model.pth'))

    if args.cuda:
        device = torch.device("cuda", args.device)
        text_cnn = text_cnn.to(device)
    try:
        train.test(test_iter, text_cnn, args)
    except KeyboardInterrupt:
        print('Exiting from testing early')
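
Every example on this page builds the classifier with model.TextCNN(args), but the module itself is not shown. Below is a minimal sketch of a Kim-style TextCNN consistent with the attributes set above (vocabulary_size, embedding_dim, vectors, class_num); the filter_num, filter_sizes, and dropout defaults are assumptions, not taken from any of these projects.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TextCNN(nn.Module):
    """Sketch of a Kim-style TextCNN; attribute names mirror the args used above."""

    def __init__(self, args):
        super().__init__()
        filter_num = getattr(args, 'filter_num', 100)            # assumed default
        filter_sizes = getattr(args, 'filter_sizes', [3, 4, 5])  # assumed default
        self.embedding = nn.Embedding(args.vocabulary_size, args.embedding_dim)
        if getattr(args, 'vectors', None) is not None:
            self.embedding.weight.data.copy_(args.vectors)       # load pretrained vectors
            if getattr(args, 'static', False):
                self.embedding.weight.requires_grad = False      # freeze them in static mode
        # one convolution per filter size, each spanning the full embedding width
        self.convs = nn.ModuleList(
            nn.Conv2d(1, filter_num, (size, args.embedding_dim)) for size in filter_sizes)
        self.dropout = nn.Dropout(getattr(args, 'dropout', 0.5))
        self.fc = nn.Linear(filter_num * len(filter_sizes), args.class_num)

    def forward(self, x):                        # x: (batch, seq_len) token ids
        x = self.embedding(x).unsqueeze(1)       # (batch, 1, seq_len, embedding_dim)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]
        x = [F.max_pool1d(t, t.size(2)).squeeze(2) for t in x]   # max over time
        x = torch.cat(x, dim=1)
        return self.fc(self.dropout(x))          # (batch, class_num) logits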
Example No. 2
def train_test(text_field, label_field, test=False):
    train_iter, dev_iter, test_iter = load_dataset(text_field,
                                                   label_field,
                                                   args,
                                                   device=-1,
                                                   repeat=False,
                                                   shuffle=True)

    args.vocabulary_size = len(text_field.vocab)
    args.embedding_dim = text_field.vocab.vectors.size()[-1]
    args.vectors = text_field.vocab.vectors

    args.class_num = len(label_field.vocab)
    args.cuda = args.device != -1 and torch.cuda.is_available()
    print('Parameters:')
    for attr, value in sorted(args.__dict__.items()):
        if attr in {'vectors'}:
            continue
        print('\t{}={}'.format(attr.upper(), value))

    text_cnn = model.TextCNN(args)
    if args.snapshot:
        print('\nLoading model from {}...\n'.format(args.snapshot))
        text_cnn.load_state_dict(torch.load(args.snapshot))

    if args.cuda:
        device = torch.device("cuda", args.device)
        text_cnn = text_cnn.to(device)
    if args.test:
        try:
            train.eval(test_iter, text_cnn, args, True)
        except KeyboardInterrupt:
            print('Exiting from testing early')
    else:
        try:
            train.train(train_iter, dev_iter, text_cnn, args)
        except KeyboardInterrupt:
            print('Exiting from training early')
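
load_dataset is project-specific and not reproduced on this page. A sketch of what such a helper might look like with the legacy torchtext API follows; the file names, the TSV layout, and the pretrained_name/pretrained_path attributes on args are assumptions.

from torchtext import data           # torchtext <= 0.8, or torchtext.legacy.data later
from torchtext.vocab import Vectors


def load_dataset(text_field, label_field, args, **kwargs):
    # assumed layout: one "label<TAB>text" record per line
    train, dev, test = data.TabularDataset.splits(
        path='data', format='tsv',
        train='train.tsv', validation='dev.tsv', test='test.tsv',
        fields=[('label', label_field), ('text', text_field)])
    vectors = Vectors(name=args.pretrained_name, cache=args.pretrained_path)
    text_field.build_vocab(train, dev, test, vectors=vectors)
    label_field.build_vocab(train, dev, test)
    # kwargs carries device, repeat, shuffle as passed by the caller
    return data.BucketIterator.splits(
        (train, dev, test),
        batch_sizes=(args.batch_size, len(dev), len(test)),
        **kwargs)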
Example No. 3
# -*- coding:utf-8 -*-
import sys, os
import torch
import model

args = torch.load('snapshot/args.pkl')
text_field = torch.load(os.path.join(args.save_dir, 'textfield.pkl'))
label_field = torch.load(os.path.join(args.save_dir, 'labelfield.pkl'))
model = model.TextCNN(args)
model.load_state_dict(
    torch.load(os.path.join(args.save_dir, 'model_params.pkl')))
if args.cuda:
    torch.cuda.set_device(args.device)
    model = model.cuda()

model.eval()


def predict(content, title=None):
    if title:
        title = [text_field.preprocess(title)]
        title = text_field.process(title)
    content = [text_field.preprocess(content)]
    content = text_field.process(content)

    if title is not None:
        net = torch.cat([title, content], dim=1)
    else:
        net = content

    if args.cuda:
        net = net.cuda()  # assumed body; the original example is truncated at this point
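    # A plausible continuation (an assumption, not part of the source): run the
    # forward pass, take the argmax, and map it back to a label string through
    # the label vocabulary.
    output = model(net)                          # (1, class_num) logits
    _, pred = torch.max(output, dim=1)           # index of the highest-scoring class
    return label_field.vocab.itos[pred.item()]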
Example No. 4
    #TEXT.vocab.vectors: [332909, 300]
    args.vectors = TEXT.vocab.vectors

if args.multichannel:
    args.static = True
    args.nonStatic = True

###print parameters
print('Parameters:')
for attr, value in sorted(args.__dict__.items()):
    if attr in {'vectors'}:
        continue
    print('\t{}={}'.format(attr.upper(), value))

###train
textCNN = model.TextCNN(args)

#print('args.vectors ', type(args.vectors))
#print(args.vectors)
#embeddings = np.random.random((64, 432))
#embeddings = np.asarray(embeddings, dtype=int)
#embeddings = torch.from_numpy(embeddings)
#print(input)
#with SummaryWriter(log_dir='./visualLog', comment='TextCNN') as writer:
#    writer.add_graph(textCNN, (embeddings,))
#print(textCNN)

if args.modelLoadFilename:
    print('\nLoading model from {}...\n'.format(args.modelLoadFilename))
    textCNN.load_state_dict(torch.load(args.modelLoadFilename))
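
This example (and Example No. 5 below) enables a multichannel setup by switching on both the static and non-static flags. In the original TextCNN formulation this means two copies of the pretrained embeddings, one frozen and one fine-tuned, stacked as two input channels for the convolutions. A self-contained sketch of that idea, with made-up sizes:

import torch
import torch.nn as nn

vectors = torch.randn(1000, 300)                              # stand-in for TEXT.vocab.vectors
tuned = nn.Embedding.from_pretrained(vectors, freeze=False)   # non-static channel
frozen = nn.Embedding.from_pretrained(vectors, freeze=True)   # static channel

ids = torch.randint(0, 1000, (8, 40))                     # (batch, seq_len)
channels = torch.stack([tuned(ids), frozen(ids)], dim=1)  # (8, 2, 40, 300)
convs = nn.ModuleList(nn.Conv2d(2, 100, (k, 300)) for k in (3, 4, 5))
feature_maps = [conv(channels) for conv in convs]         # each (8, 100, 40 - k + 1, 1)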
Example No. 5
args.vocabulary_size = len(text_field.vocab)
if args.static:
    args.embedding_dim = text_field.vocab.vectors.size()[-1]
    args.vectors = text_field.vocab.vectors
if args.multichannel:
    args.static = True
    args.non_static = True
args.class_num = len(label_field.vocab)
args.cuda = args.device != -1 and torch.cuda.is_available()
args.filter_sizes = [int(size) for size in args.filter_sizes.split(',')]

print('Parameters:')
for attr, value in sorted(args.__dict__.items()):
    if attr in {'vectors'}:
        continue
    print('\t{}={}'.format(attr.upper(), value))

text_cnn = model.TextCNN(args)
if args.snapshot:
    print('\nLoading model from {}...\n'.format(args.snapshot))
    text_cnn.load_state_dict(torch.load(args.snapshot))

if args.cuda:
    torch.cuda.set_device(args.device)
    text_cnn = text_cnn.cuda()
try:
    train.train(train_iter, dev_iter, text_cnn, args)
except KeyboardInterrupt:
    print('Exiting from training early')
Example No. 6
    if not os.path.isdir("logs"):
        os.mkdir("logs")
    if not os.path.isdir("model"):
        os.mkdir("model")

    print("Loading data...")
    train_iter, text_field, label_field = data.fasttext_dataloader(
        "data/train.txt", conf.batch_size)
    data.save_vocab(text_field.vocab, "model/text.vocab")
    data.save_vocab(label_field.vocab, "model/label.vocab")

    # Update configurations
    conf.embed_num = len(text_field.vocab)
    conf.class_num = len(label_field.vocab) - 1
    conf.kernel_sizes = [int(k) for k in conf.kernel_sizes.split(',')]

    # model
    if os.path.exists(args.model):
        print('Loading model from {}...'.format(args.model))
        cnn = torch.load(args.model)
    else:
        cnn = model.TextCNN(conf)

    print(cnn)
    try:
        model.train(train_iter, cnn, conf)
    except KeyboardInterrupt:
        print('-' * 80)
        print('Exiting from training early')
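
data.save_vocab and data.fasttext_dataloader are project-specific helpers. A sketch of a save/load pair that would support restoring the vocabularies at inference time (the implementation is an assumption; torchtext Vocab objects are ordinarily picklable):

import torch


def save_vocab(vocab, path):
    torch.save(vocab, path)      # Vocab objects pickle cleanly, so torch.save suffices


def load_vocab(path):
    return torch.load(path)      # e.g. text_vocab = load_vocab("model/text.vocab")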
Example No. 7
parse = argparse.ArgumentParser()

parse.add_argument('--mode', default='train', help='train/test')
parse.add_argument('--cuda', default=False)
parse.add_argument('--device', default="3")

args = parse.parse_args()

mode = args.mode
use_cuda = args.cuda
device_id = args.device
if use_cuda:
    torch.cuda.manual_seed(1)
    os.environ["CUDA_VISIBLE_DEVICES"] = device_id

model = model.TextCNN(config)
if use_cuda:
    model.cuda()
if mode == "test":
    print("loading model")
    state_dict = torch.load(open(config.model, 'rb'))
    model.load_state_dict(state_dict)


def train_step(train_data, test_data, optimizer):
    model.train()
    count = 0
    total_loss = 0
    for j in range(0, len(train_data), config.batch_size):
        optimizer.zero_grad()
        print("run batch : %d " % j)
Example No. 8
LABEL.build_vocab(train_data)

# check the vocabulary size
print(len(TEXT.vocab))  # 19206
# inspect the first 10 tokens in the vocabulary
print(TEXT.vocab.itos[:10])  # ['<unk>', '<pad>', ',', 'the', 'a', 'and', 'of', 'to', '.', 'is']
# look up the index of the word 'name'; stoi is essentially a dict
print(TEXT.vocab.stoi['name'])  # 2063

# build the data iterators
train_iterator, test_iterator = data.BucketIterator.splits((train_data, test_data),
                                                           batch_size=config.BATCH_SIZE,
                                                           sort=False)

# build the model
text_cnn = model.TextCNN(len(TEXT.vocab), config.EMBEDDING_SIZE, len(LABEL.vocab)).to(device)
# choose the optimizer
optimizer = optim.Adam(text_cnn.parameters(), lr=config.LEARNING_RATE)
# choose the loss function
criterion = nn.CrossEntropyLoss()

# collect results for plotting
model_train_acc, model_test_acc = [], []
start = time.time()
# train the model
for epoch in range(config.EPOCH):
    train_acc = utils.train(text_cnn, train_iterator, optimizer, criterion)
    print("epoch = {}, 训练准确率={}".format(epoch + 1, train_acc))

    test_acc = utils.evaluate(text_cnn, test_iterator)
    print("epoch = {}, 测试准确率={}".format(epoch + 1, test_acc))
Example No. 9
                               word_idx)  # convert sentences to an id matrix

total_num = 0.0
correct_num = 0.0

for index in range(10):  #cross validation
    train_data, test_data, train_label, test_label = mydata.data_train_test(
        data_idx, label, cv, index)
    dataset1 = mydata.subDataset(train_data, train_label)
    train_loader = DataLoader(dataset1,
                              batch_size=100,
                              shuffle=True,
                              num_workers=0)
    dataset2 = mydata.subDataset(test_data, test_label)
    test_loader = DataLoader(dataset2,
                             batch_size=100,
                             shuffle=True,
                             num_workers=0)

    net = model.TextCNN(vec_dim=300,
                        kernel_num=50,
                        vec_num=30,
                        label_num=2,
                        kernel_list=[3, 4, 5])
    train.train_textcnn_model(W, net, train_loader, epoch=10, lr=0.0005)
    a, b = train.textcnn_model_test(W, net, test_loader)
    total_num += a
    correct_num += b

print('Average accuracy of the network on test set: %f %%' %
      (100 * correct_num / total_num))
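
mydata.data_train_test is project-specific; below is a sketch of a fold split consistent with how it is called here, assuming cv holds one fold id (0-9) per sample.

import numpy as np


def data_train_test(data_idx, label, cv, index):
    # samples whose fold id equals `index` form the test split, the rest train
    mask = np.asarray(cv) == index
    train_data = [d for d, is_test in zip(data_idx, mask) if not is_test]
    test_data = [d for d, is_test in zip(data_idx, mask) if is_test]
    train_label = [y for y, is_test in zip(label, mask) if not is_test]
    test_label = [y for y, is_test in zip(label, mask) if is_test]
    return train_data, test_data, train_label, test_label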