def evaluate(model, iterator, criterion, label_list):
    """Run one full evaluation pass over `iterator`.

    Puts the model in eval mode, accumulates predictions/labels across all
    batches, and returns ``(mean_loss, accuracy, per_label_report)``.

    Args:
        model: callable taking a batch and returning logits.
        iterator: iterable of batches; each batch exposes a ``.label`` tensor.
        criterion: loss function applied to ``(logits, labels)``.
        label_list: list of label names; its length is the number of classes.

    Returns:
        Tuple of (average loss per batch, accuracy, classification report)
        as produced by ``classifiction_metric``.

    NOTE: the caller is responsible for restoring ``model.train()`` afterwards.
    """
    model.eval()
    epoch_loss = 0
    all_preds = np.array([], dtype=int)
    all_labels = np.array([], dtype=int)

    # One no_grad context covers the whole loop; the original wrapped each
    # batch in a second, redundant `with torch.no_grad():` block.
    with torch.no_grad():
        for batch in iterator:
            logits = model(batch)
            loss = criterion(logits.view(-1, len(label_list)), batch.label)

            labels = batch.label.detach().cpu().numpy()
            preds = np.argmax(logits.detach().cpu().numpy(), axis=1)

            all_preds = np.append(all_preds, preds)
            all_labels = np.append(all_labels, labels)
            epoch_loss += loss.item()

    acc, report = classifiction_metric(all_preds, all_labels, label_list)
    return epoch_loss / len(iterator), acc, report
def train(epoch_num, model, train_dataloader, dev_dataloader, optimizer,
          criterion, label_list, out_model_file, log_dir, print_step, clip):
    """Train `model` for `epoch_num` epochs with periodic dev evaluation.

    Every `print_step` optimizer steps, computes train/dev loss, accuracy and
    per-label F1, logs them to TensorBoard under a timestamped run directory,
    and saves the model state dict to `out_model_file` whenever dev accuracy
    improves.

    Args:
        epoch_num: number of epochs (coerced via ``int``).
        model: the network; called directly on each batch.
        train_dataloader / dev_dataloader: iterables of batches with ``.label``.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss over ``(logits, labels)``.
        label_list: label names; length defines the number of classes.
        out_model_file: path for ``torch.save(model.state_dict(), ...)``.
        log_dir: prefix for the TensorBoard run directory.
        print_step: interval (in global steps) between evaluations.
        clip: max gradient norm for ``clip_grad_norm_``.
    """
    model.train()

    # Run directory name carries a timestamp so successive runs don't collide.
    logging_dir = log_dir + time.strftime("%Y-%m-%d-%H_%M_%S",
                                          time.localtime(time.time()))
    writer = SummaryWriter(log_dir=logging_dir)

    global_step = 0
    best_acc = 0.0  # best dev accuracy seen so far; gates checkpointing

    for epoch in range(int(epoch_num)):
        # BUG FIX: original lacked the f-string prefix and printed the
        # literal text "{epoch + 1:02}" instead of the epoch number.
        print(f'---------------- Epoch: {epoch + 1:02} ----------')

        epoch_loss = 0
        train_steps = 0
        all_preds = np.array([], dtype=int)
        all_labels = np.array([], dtype=int)

        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            # zero the parameter gradients
            optimizer.zero_grad()

            logits = model(batch)
            loss = criterion(logits.view(-1, len(label_list)), batch.label)

            labels = batch.label.detach().cpu().numpy()
            preds = np.argmax(logits.detach().cpu().numpy(), axis=1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

            global_step += 1
            epoch_loss += loss.item()
            train_steps += 1
            all_preds = np.append(all_preds, preds)
            all_labels = np.append(all_labels, labels)

            if global_step % print_step == 0:
                # Running training averages since the start of this epoch.
                train_loss = epoch_loss / train_steps
                train_acc, train_report = classifiction_metric(
                    all_preds, all_labels, label_list)

                dev_loss, dev_acc, dev_report = evaluate(
                    model, dev_dataloader, criterion, label_list)

                c = global_step // print_step
                writer.add_scalar("loss/train", train_loss, c)
                writer.add_scalar("loss/dev", dev_loss, c)
                writer.add_scalar("acc/train", train_acc, c)
                writer.add_scalar("acc/dev", dev_acc, c)

                # Per-class F1, plus the aggregate rows of the report.
                for label in label_list:
                    writer.add_scalar(label + ":" + "f1/train",
                                      train_report[label]['f1-score'], c)
                    writer.add_scalar(label + ":" + "f1/dev",
                                      dev_report[label]['f1-score'], c)
                for label in ['macro avg', 'weighted avg']:
                    writer.add_scalar(label + ":" + "f1/train",
                                      train_report[label]['f1-score'], c)
                    writer.add_scalar(label + ":" + "f1/dev",
                                      dev_report[label]['f1-score'], c)

                # Checkpoint on best dev accuracy. (The original also carried
                # an unused `best_dev_loss` and commented-out loss-based
                # selection; both removed as dead code.)
                if dev_acc > best_acc:
                    best_acc = dev_acc
                    torch.save(model.state_dict(), out_model_file)

                # evaluate() switched the model to eval mode; restore it.
                model.train()

    writer.close()