def trainer(epochs, model, optimizer, scheduler, train_dataloader, test_dataloader, batch_train, batch_test, device):
    """Train ``model`` for ``epochs`` epochs, evaluating on the test set after each one.

    Every batch from either dataloader is unpacked as seven tensors, in this
    order: input ids, attention mask, ``adj``, ``adj_mwe``, labels, target
    token index and a per-sample index. The per-sample index is ignored
    during training and collected during evaluation.

    Args:
        epochs: Number of passes over ``train_dataloader``; must be >= 1.
        model: Called with ``labels=...`` it must return a loss tensor;
            called without ``labels`` it must return the logits.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per training batch.
        train_dataloader: Iterable of training batches.
        test_dataloader: Iterable of evaluation batches.
        batch_train: Forwarded to the model as ``batch=`` during training.
        batch_test: Forwarded to the model as ``batch=`` during evaluation.
        device: Device every batch tensor is moved to.

    Returns:
        ``(scores, all_preds, all_labels, test_indices)`` from the last
        epoch: the ``Evaluate`` object plus the concatenated logits, labels
        and sample indices (all on CPU).

    Raises:
        ValueError: If ``epochs`` is smaller than 1, since there would be
            no evaluation results to return.
    """
    if epochs < 1:
        raise ValueError("epochs must be >= 1, got {}".format(epochs))

    max_grad_norm = 1.0

    for _ in trange(epochs, desc="Epoch"):
        # Keep collecting until a pass reports nothing left to reclaim.
        while gc.collect() > 0:
            pass

        mean_loss = _train_one_epoch(
            model, optimizer, scheduler, train_dataloader, batch_train, max_grad_norm, device
        )
        print("Train loss: {}".format(mean_loss))

        all_preds, all_labels, test_indices = _collect_predictions(
            model, test_dataloader, batch_test, device
        )
        scores = Evaluate(all_preds, all_labels)
        print('scores.accuracy()={}\nscores.precision_recall_fscore()={}'.format(
            scores.accuracy(), scores.precision_recall_fscore()))

    return scores, all_preds, all_labels, test_indices


def _train_one_epoch(model, optimizer, scheduler, dataloader, batch_size, max_grad_norm, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss."""
    model.train()

    total_loss = 0.0
    steps = 0
    for batch in dataloader:
        # Moving the whole batch once is enough; no per-tensor .to() afterwards.
        input_ids, input_mask, adj, adj_mwe, labels, target_idx, _ = (
            t.to(device) for t in batch
        )

        # Gradients accumulate by default, so clear them before the forward pass.
        optimizer.zero_grad()

        loss = model(
            input_ids,
            adj=adj,
            adj_mwe=adj_mwe,
            attention_mask=input_mask,
            labels=labels,
            batch=batch_size,
            target_token_idx=target_idx,
        )

        # NOTE(review): retain_graph=True is carried over unchanged. It is only
        # required if the same graph is backpropagated through again; confirm
        # the model needs it, otherwise it only keeps buffers alive longer.
        loss.backward(retain_graph=True)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        scheduler.step()
        # Clear again so no gradients from the last step linger into evaluation.
        optimizer.zero_grad()

        total_loss += loss.item()
        steps += 1

    # An empty dataloader has no loss to average; report NaN instead of
    # dividing by zero.
    return total_loss / steps if steps else float("nan")


def _collect_predictions(model, dataloader, batch_size, device):
    """Run ``model`` in eval mode over ``dataloader``; return CPU logits, labels and indices."""
    model.eval()

    # Seed each list with an empty tensor so an empty dataloader still yields
    # empty tensors, then concatenate once at the end instead of re-copying
    # the accumulated results on every batch.
    pred_chunks = [torch.FloatTensor()]
    label_chunks = [torch.LongTensor()]
    index_chunks = [torch.LongTensor()]

    for batch in dataloader:
        input_ids, input_mask, adj, adj_mwe, labels, target_idx, sample_idx = (
            t.to(device) for t in batch
        )

        # No gradients are needed for evaluation.
        with torch.no_grad():
            logits = model(
                input_ids,
                adj=adj,
                adj_mwe=adj_mwe,
                attention_mask=input_mask,
                batch=batch_size,
                target_token_idx=target_idx,
            )

        pred_chunks.append(logits.detach().cpu())
        label_chunks.append(labels.cpu())
        index_chunks.append(sample_idx.cpu())

    return torch.cat(pred_chunks), torch.cat(label_chunks), torch.cat(index_chunks)