def train(seed):
    """Run 5-fold cross-validated training of a GNN sequence labeler.

    Reads the dataset from hard-coded paths, trains one model per fold with
    early stopping on validation accuracy, logs misclassified words to
    '<args.output>.mistakes', and returns the per-label test accuracy
    averaged over the 5 folds.

    Args:
        seed: integer random seed for torch CPU and CUDA RNGs.

    Returns:
        dict mapping each non-'O' label to its mean test accuracy across folds.

    NOTE(review): relies on module-level globals (`args`, `read_data`, `GNN`,
    `Variable`, `nn`, `F`, `np`, `torch`) and requires a CUDA device.
    NOTE(review): `output_file` is opened but never closed — presumably
    acceptable for a one-shot script; verify.
    """
    print('random seed:', seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.enabled = False
    dataset = read_data('../data', '../graph')
    label2id = dataset.label2id
    print(label2id)
    vocab_size = dataset.vocab_size
    output_dim = len(label2id)

    def acc_to_str(acc):
        # Format a {label: accuracy} dict as '{label:0.123, ...}' for logging.
        s = ['%s:%.3f' % (label, acc[label]) for label in acc]
        return '{' + ', '.join(s) + '}'

    # Accumulates per-fold test accuracy for every label except the 'O'
    # (outside/background) tag, which is excluded from evaluation throughout.
    cross_res = {label: [] for label in label2id if label != 'O'}
    output_file = open('%s.mistakes' % args.output, 'w')
    for cross_valid in range(5):
        # Fresh model per fold; weights are not shared across folds.
        model = GNN(vocab_size=vocab_size, output_dim=output_dim, args=args)
        model.cuda()
        # print vocab_size
        # 80/10/10 split rotated by fold index.
        dataset.split_train_valid_test([0.8, 0.1, 0.1], 5, cross_valid)
        print('train:', len(dataset.train), 'valid:', len(dataset.valid),
              'test:', len(dataset.test))

        def evaluate(model, datalist, output_file=None):
            """Per-label "top prediction" accuracy over `datalist`.

            For each document and each non-'O' label, the single position with
            the highest (masked) probability for that label is taken as the
            prediction; it counts as correct iff the gold tag at that position
            is that label. If `output_file` is given, each mistake is logged
            as 'set_id fax_id label word'.
            """
            if output_file != None:
                output_file.write(
                    '#############################################\n')
            correct = {label: 0 for label in label2id if label != 'O'}
            total = len(datalist)
            model.eval()
            print_cnt = 0  # NOTE(review): never used — leftover debug counter.
            for data in datalist:
                # Old-style (pre-0.4) PyTorch: wrap tensors in Variable and
                # move to GPU; adjacency matrices and mask need no gradients.
                word, feat = Variable(data.input_word).cuda(), Variable(
                    data.input_feat).cuda()
                a_ud, a_lr = Variable(data.a_ud, requires_grad=False).cuda(), Variable(
                    data.a_lr, requires_grad=False).cuda()
                mask = Variable(data.mask, requires_grad=False).cuda()
                if args.globalnode:
                    # Model additionally predicts a document "form" class.
                    logprob, form = model(word, feat, mask, a_ud, a_lr)
                    logprob = logprob.data.view(-1, output_dim)
                else:
                    logprob = model(word, feat, mask, a_ud,
                                    a_lr).data.view(-1, output_dim)
                mask = mask.data.view(-1)
                # y_pred[i] = flat index of the most probable position for
                # label i, restricted to valid (mask==1) positions.
                y_pred = torch.LongTensor(output_dim)
                for i in range(output_dim):
                    prob = logprob[:, i].exp() * mask
                    y_pred[i] = prob.topk(k=1)[1][0]
                # y_pred = logprob.topk(k=1,dim=0)[1].view(-1)
                for label in label2id:
                    if label == 'O':
                        continue
                    labelid = label2id[label]
                    if data.output.view(-1)[y_pred[labelid]] == labelid:
                        correct[label] += 1
                    else:
                        if output_file != None:
                            # Recover the word string behind the flat index so
                            # the mistake log is human-readable.
                            num_sent, sent_len, word_len = data.input_word.size()
                            # NOTE(review): `id` shadows the builtin.
                            id = y_pred[label2id[label]]
                            word = data.words[data.sents[int(
                                id / sent_len)][id % sent_len]]
                            output_file.write(
                                '%d %d %s %s\n' % (data.set_id, data.fax_id,
                                                   label, word))
            return {label: float(correct[label]) / total for label in correct}

        batch = 1  # gradient-accumulation size; lr is scaled by it below.
        # Class weights: entity labels weighted 10x over the 'O' tag.
        weight = torch.zeros(len(label2id))
        for label, id in label2id.items():
            weight[id] = 1 if label == 'O' else 10
        # NOTE(review): `reduce=False` is the deprecated spelling of
        # reduction='none' — per-element losses are masked manually below.
        loss_function = nn.NLLLoss(weight.cuda(), reduce=False)
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     lr=args.lr / float(batch),
                                     weight_decay=args.wd)
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)
        best_acc = -1
        wait = 0  # epochs since last strict improvement (early stopping).
        for epoch in range(args.epochs):
            sum_loss = 0
            model.train()
            # random.shuffle(dataset.train)
            for idx, data in enumerate(dataset.train):
                word, feat = Variable(data.input_word).cuda(), Variable(
                    data.input_feat).cuda()
                a_ud, a_lr = Variable(data.a_ud, requires_grad=False).cuda(), Variable(
                    data.a_lr, requires_grad=False).cuda()
                mask = Variable(data.mask, requires_grad=False).cuda()
                true_output = Variable(data.output).cuda()
                if args.globalnode:
                    logprob, form = model(word, feat, mask, a_ud, a_lr)
                else:
                    logprob = model(word, feat, mask, a_ud, a_lr)
                # Mask out padding positions before averaging the NLL.
                loss = torch.mean(
                    mask.view(-1) * loss_function(logprob.view(-1, output_dim),
                                                  true_output.view(-1)))
                if args.globalnode:
                    # Auxiliary form-classification loss; set_id is 1-based.
                    true_form = Variable(torch.LongTensor([data.set_id - 1
                                                           ])).cuda()
                    loss = loss + 0.1 * F.nll_loss(form, true_form)
                sum_loss += loss.data.sum()
                loss.backward()
                # Step every `batch` documents (and on the final partial batch).
                if (idx + 1) % batch == 0 or idx + 1 == len(dataset.train):
                    optimizer.step()
                    optimizer.zero_grad()
            train_acc = evaluate(model, dataset.train)
            valid_acc = evaluate(model, dataset.valid)
            test_acc = evaluate(model, dataset.test)
            print('Epoch %d: Train Loss: %.3f Train: %s Valid: %s Test: %s' \
                  % (epoch, sum_loss, acc_to_str(train_acc),
                     acc_to_str(valid_acc), acc_to_str(test_acc)))
            # scheduler.step()
            # Model selection score: sum of log accuracies (geometric-mean
            # style) so that no single label can dominate.
            acc = np.log(list(valid_acc.values())).sum()
            if epoch < 6:
                continue  # warm-up: no checkpointing/early-stop before epoch 6.
            if acc >= best_acc:
                torch.save(model.state_dict(), args.output + '.model')
            # Ties save a checkpoint but still count toward patience.
            wait = 0 if acc > best_acc else wait + 1
            best_acc = max(acc, best_acc)
            if wait >= args.patience:
                break
        # Reload the best checkpoint and log its test-set mistakes.
        model.load_state_dict(torch.load(args.output + '.model'))
        test_acc = evaluate(model, dataset.test, output_file=output_file)
        print('########', acc_to_str(test_acc))
        for label in test_acc:
            cross_res[label].append(test_acc[label])
    print("Cross Validation Result:")
    for label in cross_res:
        cross_res[label] = np.mean(cross_res[label])
    print(acc_to_str(cross_res))
    return cross_res
def train(dataset):
    """Train a GNN labeler on a pre-split dataset with mid-epoch validation.

    Seeds every RNG from `args.seed`, builds DataLoaders over
    `dataset.train/valid/test`, trains with either a CRF log-likelihood loss
    (`args.crf`) or a masked weighted cross-entropy, validates every 20000
    batches with early stopping on mean validation F1, then reloads the best
    checkpoint and runs `test` on the test loader.

    Args:
        dataset: project dataset object exposing `.train`, `.valid`, `.test`
            collections consumable by `DataLoader`.

    Returns:
        `cross_res`: a {label: []} dict — kept for interface compatibility
        with the cross-validation variant, but never filled in (the averaging
        code at the bottom is commented out).

    NOTE(review): this redefines `train` — if it appears in the same module as
    the earlier `train(seed)`, this definition shadows it.
    NOTE(review): depends on module-level globals (`args`, `label2id`, `GNN`,
    `WORD_VOCAB_SIZE`, `CHAR_VOCAB_SIZE`, `d_output`, `to_var`, `evaluate`,
    `test`, `result_obj`) and requires CUDA.
    """
    print('random seed:', args.seed)
    # Seed every RNG source and force deterministic cuDNN for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.enabled = False
    # Kept from the cross-validation version; not populated here (see below).
    cross_res = {label: [] for label in label2id if label != 'O'}
    for cross_valid in range(1):  # single "fold" — loop kept for symmetry.
        # print('cross_valid', cross_valid)
        model = GNN(word_vocab_size=WORD_VOCAB_SIZE,
                    char_vocab_size=CHAR_VOCAB_SIZE,
                    d_output=d_output,
                    args=args)
        model.cuda()
        # print vocab_size
        # print('split dataset')
        # dataset.split_train_valid_test_bycase([0.5, 0.1, 0.4], 5, cross_valid)
        print('train:', len(dataset.train), 'valid:', len(dataset.valid),
              'test:', len(dataset.test))
        sys.stdout.flush()
        train_dataloader = DataLoader(dataset.train, batch_size=args.batch,
                                      shuffle=True)
        valid_dataloader = DataLoader(dataset.valid, batch_size=args.batch)
        test_dataloader = DataLoader(dataset.test, batch_size=args.batch)
        # Class weights: entity labels weighted 2x over the 'O' tag.
        weight = torch.zeros(len(label2id))
        for label, idx in label2id.items():
            weight[idx] = 1 if label == 'O' else 2
        # NOTE(review): `reduce=False` is the deprecated spelling of
        # reduction='none' — per-token losses are masked manually below.
        loss_function = nn.CrossEntropyLoss(weight.cuda(), reduce=False)
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     lr=args.lr, weight_decay=args.wd)
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)
        best_acc = -1
        wait = 0        # validation rounds since last strict improvement.
        batch_cnt = 0   # global batch counter across epochs (drives eval cadence).
        for epoch in range(args.epochs):
            total_loss = 0
            pending_loss = None  # NOTE(review): never used — leftover.
            model.train()
            # random.shuffle(dataset.train)
            load_time, forward_time, backward_time = 0, 0, 0
            model.clear_time()
            # tqdm progress is redirected to a throwaway per-epoch log file so
            # it doesn't clutter stdout; the file is deleted after the epoch.
            train_log = open(args.save_path + '_train.log', 'w')
            for tensors, batch in tqdm(train_dataloader, file=train_log,
                                       mininterval=60):
                # print(batch[0].case_id, batch[0].doc_id, batch[0].page_id)
                start = time.time()
                data, data_word, pos, length, mask, label, adjs = to_var(
                    tensors, cuda=args.cuda)
                # Input is batched as (batch, document, sentence, word);
                # word_len is unpacked but unused.
                batch_size, docu_len, sent_len, word_len = data.size()
                load_time += (time.time() - start)
                start = time.time()
                logit = model(data, data_word, pos, length, mask, adjs)
                forward_time += (time.time() - start)
                start = time.time()
                if args.crf:
                    # Flatten (batch, document) so each document row is one
                    # CRF sequence; skip zero-length (padding) rows.
                    logit = logit.view(batch_size * docu_len, sent_len, -1)
                    mask = mask.view(batch_size * docu_len, -1)
                    length = length.view(batch_size * docu_len)
                    label = label.view(batch_size * docu_len, -1)
                    loss = -model.crf_layer.loglikelihood(
                        logit, mask, length, label)
                    loss = torch.masked_select(loss,
                                               torch.gt(length, 0)).mean()
                else:
                    # Token-level weighted CE, averaged over unmasked tokens.
                    loss = loss_function(logit.view(-1, d_output),
                                         label.view(-1))
                    loss = torch.masked_select(loss, mask.view(-1)).mean()
                total_loss += loss.data.sum()
                # print(total_loss, batch[0].case_id, batch[0].doc_id, batch[0].page_id)
                if math.isnan(total_loss):
                    # NOTE(review): hard-exits the process on divergence.
                    print('Loss is NaN!')
                    exit()
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                backward_time += (time.time() - start)
                batch_cnt += 1
                # Validate (and maybe checkpoint) every 20000 batches only.
                if batch_cnt % 20000 != 0:
                    continue
                # print('load %f forward %f backward %f'%(load_time, forward_time, backward_time))
                # model.print_time()
                valid_acc, valid_prec, valid_recall, valid_f1 = evaluate(
                    model, valid_dataloader, args=args)
                print('Epoch %d: Train Loss: %.3f Valid Acc: %.5f' %
                      (epoch, total_loss, valid_acc))
                # print(acc_to_str(valid_f1))
                # scheduler.step()
                # Model-selection score: mean F1 over labels.
                acc = np.mean(list(valid_f1.values()))  # valid_acc
                print(acc)
                if acc >= best_acc:
                    # Checkpoint model + args together, and record the
                    # validation metrics of the best model seen so far.
                    obj = {'args': args, 'model': model.state_dict()}
                    torch.save(obj, args.save_path + '.model')
                    result_obj['valid_prec'] = np.mean(
                        list(valid_prec.values()))
                    result_obj['valid_recall'] = np.mean(
                        list(valid_recall.values()))
                    result_obj['valid_f1'] = np.mean(list(valid_f1.values()))
                # Ties save a checkpoint but still count toward patience.
                wait = 0 if acc > best_acc else wait + 1
                best_acc = max(acc, best_acc)
                model.train()  # evaluate() switched the model to eval mode.
                sys.stdout.flush()
                if wait >= args.patience:
                    break
            train_log.close()
            os.remove(args.save_path + '_train.log')
            # Propagate the early-stop break out of the epoch loop too.
            if wait >= args.patience:
                break
        # Reload the best checkpoint and run the final test pass.
        obj = torch.load(args.save_path + '.model')
        model.load_state_dict(obj['model'])
        test(test_dataloader, model)
    # print("Cross Validation Result:")
    # for label in cross_res:
    #     cross_res[label] = np.mean(cross_res[label])
    # print(acc_to_str(cross_res))
    return cross_res