Example #1
def predict(test_loader, model, config, device):  # for test set
    """
    predicts labels on unseen data (test set)
    :param test_loader: dataloader torch object with test data
    :param model: trained model
    :param config: config json file
    :param device: device to run the model on
    :return: predictions for the given dataset, plus the average loss, accuracy and f1 score over the whole dataset
    """
    test_loss = []
    predictions = []
    accuracy = []
    f1_scores = []
    model.to(device)
    model.eval()  # make sure dropout/batchnorm layers run in inference mode
    for batch in test_loader:
        batch["device"] = device
        out = model(batch).to("cpu")
        if config["model"]["classification"] == "multi":
            test_loss.append(multi_class_cross_entropy(out, batch["l"]).item())
            _, prediction = torch.max(out, 1)
            prediction = prediction.tolist()
        else:  # binary
            test_loss.append(
                binary_class_cross_entropy(out.squeeze(),
                                           batch["l"].float()).item())
            prediction = convert_logits_to_binary_predictions(out)
        _, accur = get_accuracy(prediction, batch["l"])
        f1 = f1_score(y_pred=prediction, y_true=batch["l"], average="weighted")
        predictions.append(prediction)
        accuracy.append(accur)
        f1_scores.append(f1)
    # flatten the per-batch lists into a single list of predictions
    predictions = [item for sublist in predictions for item in sublist]
    return (predictions, np.average(test_loss),
            np.average(accuracy), np.average(f1_scores))
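predict relies on two helpers defined elsewhere in the repository, convert_logits_to_binary_predictions and get_accuracy. Minimal sketches of their assumed behaviour (the actual implementations may differ):

import torch

def convert_logits_to_binary_predictions(logits):
    # assumed: one raw logit per instance; apply a sigmoid, then threshold at 0.5
    probabilities = torch.sigmoid(logits.squeeze())
    return (probabilities >= 0.5).long().tolist()

def get_accuracy(predictions, labels):
    # assumed: returns (number of correct predictions, accuracy),
    # matching the "_, accur = get_accuracy(...)" unpacking above
    correct = sum(int(p == int(l)) for p, l in zip(predictions, labels))
    return correct, correct / len(labels)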
Example #2
def test_multiclass_cross_entropy(self):
    """
    multiclass classification: tests whether the loss from the implemented
    function matches the manually calculated loss
    """
    expected_loss = self.calculate_manually(self.output_tensor, self.target_labels)
    loss = multi_class_cross_entropy(output=self.output_tensor, target=self.target_labels)
    np.testing.assert_allclose(loss, expected_loss)
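All of the examples on this page call multi_class_cross_entropy, whose definition is not shown. Assuming the models emit raw, unnormalized class scores, it is plausibly a thin wrapper around PyTorch's built-in cross entropy; a minimal sketch, not the repository's actual code:

import torch.nn.functional as F

def multi_class_cross_entropy(output, target):
    # mean cross-entropy over the batch, expecting raw scores of shape
    # (batch_size, num_classes) and integer class labels of shape (batch_size,)
    return F.cross_entropy(output, target)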
Example #3
def test_matrix_transfer(self):
    """Test whether the transfer matrix classification model can be called and whether the loss can be computed"""
    batch = next(iter(self.data_loader))
    batch["device"] = "cpu"
    output = self.model_transfer(batch)
    loss = multi_class_cross_entropy(output, batch["l"]).item()
    np.testing.assert_equal(math.isnan(loss), False)
    np.testing.assert_equal(loss >= 0, True)
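The transfer model tested here presumably starts from a saved checkpoint such as the one written by train_matrix_classifier in Example #5 below. Restoring such a checkpoint uses PyTorch's standard state-dict API; a sketch, with MatrixClassifier standing in for whatever class the repository actually uses:

import torch

model = MatrixClassifier()  # hypothetical class name and constructor
model.load_state_dict(torch.load("models/matrix_classifier"))
model.eval()  # disable dropout/batchnorm updates before evaluation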
Example #4
def test_model_multiclass(self):
    """
    Test whether the multiclass classifier can be run and whether the loss can be
    computed. The loss should be a non-negative number and not NaN
    """
    self.optimizer_multi.zero_grad()
    output = self.model_multiclass(self.batch_multi)
    loss = multi_class_cross_entropy(output, self.label_multi).item()
    np.testing.assert_equal(math.isnan(loss), False)
    np.testing.assert_equal(loss >= 0, True)
Example #5
def train_matrix_classifier(self):
    """Auxiliary method to train and save a standard matrix classification model"""
    optimizer = optim.Adam(self.model_multiclass.parameters())
    for batch in self.data_loader:
        optimizer.zero_grad()
        batch["device"] = "cpu"
        output = self.model_multiclass(batch)
        loss = multi_class_cross_entropy(output, batch["l"])
        loss.backward()
        optimizer.step()
    torch.save(self.model_multiclass.state_dict(),
               "models/matrix_classifier")
Example #6
def test_model_multiclass(self):
    """
    Test whether the multiclass classifier can be run and whether the loss can be
    computed. The loss should be a non-negative number and not NaN
    """
    data_loader = DataLoader(self._static_dataset,
                             batch_size=64,
                             shuffle=True,
                             num_workers=0)

    for batch in data_loader:
        # context is a list of lists of word embeddings
        batch["device"] = "cpu"
        out = self.model(batch).squeeze()
        loss = multi_class_cross_entropy(out, batch["l"])
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        break
    loss = loss.detach().numpy()
    np.testing.assert_equal(math.isnan(loss), False)
    np.testing.assert_equal(loss >= 0, True)
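Throughout these snippets a batch is a plain dictionary: the labels sit under the key "l" and a "device" entry is injected before each forward pass. A hypothetical minimal Dataset that produces such batches, for orientation only (PyTorch's default collate function merges the per-item dicts into one batch dict, which is why batch["l"] above is a 1-d label tensor):

from torch.utils.data import Dataset

class ToyDictDataset(Dataset):
    # illustrative only; the repository's real datasets carry more keys
    def __init__(self, features, labels):
        self._x = features  # float tensor, one row per instance
        self._y = labels    # long tensor of class indices

    def __len__(self):
        return len(self._y)

    def __getitem__(self, idx):
        return {"x": self._x[idx], "l": self._y[idx]}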
Example #7
def train_multiclass(config, train_loader, valid_loader, model_path, device):
    """
    method to train a multiclass classification model
    :param config: config json file
    :param train_loader: dataloader torch object with training data
    :param valid_loader: dataloader torch object with validation data
    :param model_path: path where the best model checkpoint is saved
    :param device: device to train on
    :return: the trained model
    """
    model = init_classifier(config)
    model.to(device)
    optimizer = optim.Adam(model.parameters())  # or add an if statement here for choosing a different optimizer
    current_patience = 0
    tolerance = 1e-5
    lowest_loss = float("inf")
    best_epoch = 1
    epoch = 1
    train_loss = 0.0
    best_accuracy = 0.0
    best_f1 = 0.0
    early_stopping_criterion = config["validation_metric"]
    total_train_losses = []
    total_val_losses = []
    for epoch in range(1, config["num_epochs"] + 1):
        # training loop over all batches
        model.train()
        # these store the losses and accuracies for each batch for one epoch
        train_losses = []
        valid_losses = []
        valid_accuracies = []
        valid_f1_scores = []
        for batch in train_loader:
            batch["device"] = device
            out = model(batch).to("cpu")
            loss = multi_class_cross_entropy(out, batch["l"])
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_losses.append(loss.item())

        model.eval()
        for batch in valid_loader:
            batch["device"] = device
            out = model(batch).to("cpu")
            _, predictions = torch.max(out, 1)
            loss = multi_class_cross_entropy(out, batch["l"])
            valid_losses.append(loss.item())
            _, accur = get_accuracy(predictions.tolist(), batch["l"])
            f1 = f1_score(y_true=batch["l"],
                          y_pred=predictions,
                          average="weighted")
            valid_accuracies.append(accur)

            valid_f1_scores.append(f1)

        # calculate average loss and accuracy over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        valid_accuracy = np.average(valid_accuracies)
        valid_f1 = np.average(valid_f1_scores)
        total_train_losses.append(train_loss)
        total_val_losses.append(valid_loss)

        # early stopping on the highest f1 score
        if early_stopping_criterion == "f1":
            # an improvement must exceed the tolerance, mirroring the loss branch
            if valid_f1 - best_f1 > tolerance:
                lowest_loss = valid_loss
                best_f1 = valid_f1
                best_epoch = epoch
                best_accuracy = valid_accuracy
                current_patience = 0
                torch.save(model.state_dict(), model_path)
            else:
                current_patience += 1
        # early stopping on the lowest validation loss
        else:
            if lowest_loss - valid_loss > tolerance:
                lowest_loss = valid_loss
                best_epoch = epoch
                best_accuracy = valid_accuracy
                best_f1 = valid_f1
                current_patience = 0
                torch.save(model.state_dict(), model_path)
            else:
                current_patience += 1
        if current_patience > config["patience"]:
            break

        logger.info(
            "current patience: %d , epoch %d , train loss: %.5f, validation loss: %.5f, accuracy: %.5f, f1 score: %.5f"
            % (current_patience, epoch, train_loss, valid_loss, valid_accuracy,
               valid_f1))
    logger.info(
        "training finished after %d epochs, train loss: %.5f, best epoch: %d, best validation loss: %.5f, "
        "best validation accuracy: %.5f, best f1 score: %.5f" %
        (epoch, train_loss, best_epoch, lowest_loss, best_accuracy, best_f1))
    if config["plot_curves"]:
        path = str(Path(config["model_path"]).joinpath(
            save_name + "_learning_curves.png"))
        plot_learning_curves(training_losses=total_train_losses,
                             validation_losses=total_val_losses,
                             save_path=path)
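Apart from the model-specific settings consumed by init_classifier, train_multiclass only reads a handful of config keys: validation_metric, num_epochs, patience, plot_curves and model_path. A hypothetical minimal configuration and call (the values are illustrative, not taken from the repository):

config = {
    "validation_metric": "f1",  # any other value falls back to loss-based early stopping
    "num_epochs": 50,
    "patience": 5,
    "plot_curves": False,
    "model_path": "models",
    # ...plus whatever keys init_classifier needs to build the model
}
train_multiclass(config, train_loader, valid_loader,
                 model_path="models/multiclass_classifier", device="cpu")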