Example #1
import gc

import torch
from tqdm import trange

# Evaluate is a project-level metrics helper (defined elsewhere in this repo);
# a sketch of its inferred interface appears after the function.

def trainer(epochs, model, optimizer, scheduler, train_dataloader,
            test_dataloader, batch_train, batch_test, device):

    max_grad_norm = 1.0  # gradient-clipping threshold
    train_loss_set = []  # per-step training losses, kept for later inspection

    for e in trange(epochs, desc="Epoch"):

        # Run the garbage collector repeatedly until nothing is left to free,
        # reclaiming memory between epochs
        while gc.collect() > 0:
            pass

        # Training
        # Set our model to training mode (as opposed to evaluation mode)
        model.train()

        # Optionally freeze the BERT encoder after a few epochs:
        # if e > 8:
        #     model.freeze_bert()

        # Tracking variables
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0

        # Train the data for one epoch
        for step, batch in enumerate(train_dataloader):
            # Add batch to GPU
            batch = tuple(t.to(device) for t in batch)
            # Unpack the inputs from our dataloader
            b_input_ids, b_input_mask, b_adj, b_adj_mwe, b_labels, b_target_idx, _ = batch

            # Clear out the gradients (by default they accumulate)
            optimizer.zero_grad()
            # Forward pass (BERT + GCN with MWE adjacency); with labels
            # provided, the model returns the loss directly. The batch was
            # already moved to the device above, so no extra .to() is needed.
            loss = model(b_input_ids, adj=b_adj, adj_mwe=b_adj_mwe,
                         attention_mask=b_input_mask, labels=b_labels,
                         batch=batch_train, target_token_idx=b_target_idx)

            train_loss_set.append(loss.item())
            # Backward pass; retain_graph is unnecessary here, since the
            # graph is rebuilt on every forward pass
            loss.backward()
            # Clip gradients to stabilize training
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            # Update parameters, then advance the per-step learning-rate schedule
            optimizer.step()
            scheduler.step()

            # Update tracking variables
            tr_loss += loss.item()
            nb_tr_examples += b_input_ids.size(0)
            nb_tr_steps += 1

        print("Train loss: {}".format(tr_loss / nb_tr_steps))

        # Validation

        # Put model in evaluation mode to evaluate loss on the validation set
        model.eval()

        # Accumulators over the full test set (reset each epoch; the last
        # epoch's predictions are scored after the loop)
        all_preds = torch.FloatTensor()
        all_labels = torch.LongTensor()
        test_indices = torch.LongTensor()

        # Evaluate data for one epoch
        for batch in test_dataloader:
            # Add batch to GPU
            batch = tuple(t.to(device) for t in batch)
            # Unpack the inputs from our dataloader
            b_input_ids, b_input_mask, b_adj, b_adj_mwe, b_labels, b_target_idx, test_idx = batch
            # Telling the model not to compute or store gradients, saving memory and speeding up validation
            with torch.no_grad():
                # Forward pass (BERT + GCN with MWE adjacency); without
                # labels, the model returns the logit predictions
                logits = model(b_input_ids, adj=b_adj, adj_mwe=b_adj_mwe,
                               attention_mask=b_input_mask,
                               batch=batch_test, target_token_idx=b_target_idx)

                # Move logits and labels to CPU
                logits = logits.detach().cpu()
                label_ids = b_labels.cpu()
                test_idx = test_idx.cpu()

                all_preds = torch.cat([all_preds, logits])
                all_labels = torch.cat([all_labels, label_ids])
                test_indices = torch.cat([test_indices, test_idx])

    scores = Evaluate(all_preds, all_labels)
    print('scores.accuracy()={}\nscores.precision_recall_fscore()={}'.format(
        scores.accuracy(), scores.precision_recall_fscore()))

    return scores, all_preds, all_labels, test_indices
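
Two hedged sketches follow. First, the interface of the Evaluate helper, inferred only from how it is called above; the argmax over 2-D logits and the macro average are assumptions, not the project's actual code:

from sklearn.metrics import accuracy_score, precision_recall_fscore_support


class Evaluate:
    """Minimal stand-in matching the calls above; the real class lives elsewhere."""

    def __init__(self, preds, labels):
        # preds: (N, num_classes) logits on CPU; labels: (N,) gold class ids
        self.preds = preds.argmax(dim=1).numpy()
        self.labels = labels.numpy()

    def accuracy(self):
        return accuracy_score(self.labels, self.preds)

    def precision_recall_fscore(self):
        # average="macro" is an assumption; the original may use another setting
        return precision_recall_fscore_support(
            self.labels, self.preds, average="macro")

Second, a minimal sketch of how the trainer might be wired up. The model, dataloaders, learning rate, and batch sizes are placeholders; the linear warmup schedule from transformers is an assumption that matches the per-batch scheduler.step() above:

import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

epochs = 4
batch_train = batch_test = 32  # placeholder batch sizes
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)  # a BERT + GCN model accepting adj / adj_mwe (assumed defined)
optimizer = AdamW(model.parameters(), lr=2e-5)
num_training_steps = len(train_dataloader) * epochs  # train_dataloader assumed defined
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

scores, all_preds, all_labels, test_indices = trainer(
    epochs, model, optimizer, scheduler, train_dataloader,
    test_dataloader, batch_train, batch_test, device)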