Code Example #1
def test(text_inputs, attention_inputs, labels, model, criterion, device,
         n_way):
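    """
    Testing Loop

    :param text_inputs: The tokenized batch of text testing inputs
    :param attention_inputs: The attention mask that indicates which values of the text inputs are padding
    :param labels: The labels of the batched testing data
    :param model: The model that the data will be tested on
    :param criterion: The loss function
    :param device: The type of device that the testing will occur on
    :param n_way: The number of classes that the data can take
    :return: A tuple containing the validation loss, validation accuracy, average class accuracy,
    Macro-F1 score and a list of class F1 scores
    """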
    # An F1 Score of 0 indicates that it is invalid
    model.eval()
    true_positive = list(0. for i in range(n_way))  # Number of correctly predicted samples per class
    total_truth = list(0. for i in range(n_way))  # Number of ground truths per class
    predicted_positive = list(0. for i in range(n_way))  # Number of predicted samples per class
    correct_total = 0  # Total correctly predicted samples
    total = 0  # Total samples
    with torch.no_grad():
        # Test the entire query set in one batch
        pred = model(text_inputs, attention_inputs)
        loss = criterion(pred, labels)
        val_loss = loss.item()  # Running validation loss
        _, predicted = torch.max(pred, 1)
        correct = (predicted == labels).squeeze()  # Samples that are correctly predicted
        correct_total += (predicted == labels).sum().item()
        total += labels.size(0)

        for i in range(len(predicted)):
            label = labels[i]
            true_positive[label] += correct[i].item()
            total_truth[label] += 1
            predicted_positive[predicted[i].item()] += 1  # True Positive + False Positive

    accuracy, macro_accuracy, f1_score, class_f1 = metrics(
        true_positive, total_truth, predicted_positive, correct_total, total)

    return val_loss, accuracy, macro_accuracy, f1_score, class_f1
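Each of these test loops hands its raw per-class counts to a `metrics()` helper that is not reproduced in this listing. Below is a minimal sketch of what such a helper could look like, reconstructed only from the call sites and the comments above; the exact formulas used in the original project are an assumption.

def metrics(true_positive, total_truth, predicted_positive, correct_total, total):
    # Sketch of the metrics() helper assumed by the test loops in these examples.
    # Per-class F1 scores default to 0, which the callers treat as "invalid".
    n_way = len(total_truth)
    accuracy = correct_total / total  # Overall accuracy over all samples

    # Average of per-class recalls (classes with no ground truths count as 0)
    macro_accuracy = sum(
        tp / truth if truth > 0 else 0.0
        for tp, truth in zip(true_positive, total_truth)) / n_way

    class_f1 = []
    for tp, truth, pred in zip(true_positive, total_truth, predicted_positive):
        precision = tp / pred if pred > 0 else 0.0
        recall = tp / truth if truth > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
        class_f1.append(f1)
    f1_score = sum(class_f1) / n_way  # Macro-F1

    return accuracy, macro_accuracy, f1_score, class_f1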
Code Example #2
def test(image_inputs, text_inputs, attention_inputs, labels, model, criterion,
         device, n_way):
    """
    Testing Loop

    :param image_inputs: The batch of image testing inputs
    :param text_inputs: The tokenized batch of text testing inputs
    :param attention_inputs: The attention mask that indicates which values of the text inputs are padding
    :param labels: The labels of the batched testing data
    :param model: The model that the data will be tested on
    :param criterion: The loss function
    :param device: The type of device that the testing will occur on
    :param n_way: The number of classes that the data can take
    :return: A tuple containing the validation loss, validation accuracy, average class accuracy,
    Macro-F1 score and a list of class F1 scores
    """
    # An F1 Score of 0 indicates that it is invalid
    model.eval()
    true_positive = list(0. for i in range(n_way))  # Number of correctly predicted samples per class
    total_truth = list(0. for i in range(n_way))  # Number of ground truths per class
    predicted_positive = list(0. for i in range(n_way))  # Number of predicted samples per class
    correct_total = 0  # Total correctly predicted samples
    total = 0  # Total samples
    with torch.no_grad():
        # Test the entire query set in one batch
        pred = model(image_inputs, text_inputs, attention_inputs)
        loss = criterion(pred, labels)
        val_loss = loss.item()  # Running validation loss
        _, predicted = torch.max(pred, 1)
        correct = (predicted == labels).squeeze()  # Samples that are correctly predicted
        correct_total += (predicted == labels).sum().item()
        total += labels.size(0)

        for i in range(len(predicted)):
            label = labels[i]
            true_positive[label] += correct[i].item()
            total_truth[label] += 1
            predicted_positive[predicted[i].item()] += 1  # True Positive + False Positive

    accuracy, macro_accuracy, f1_score, class_f1 = metrics(
        true_positive, total_truth, predicted_positive, correct_total, total)

    return val_loss, accuracy, macro_accuracy, f1_score, class_f1
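A minimal usage sketch for the multimodal test() above follows. The BERT tokenizer, the 224x224 image size, and the placeholder model are assumptions made only for illustration; just the argument names and their order come from the function signature.

import torch
from transformers import AutoTokenizer

class DummyMultimodalModel(torch.nn.Module):
    # Placeholder standing in for the real image+text model (assumption)
    def __init__(self, n_way):
        super().__init__()
        self.fc = torch.nn.Linear(1, n_way)

    def forward(self, image_inputs, text_inputs, attention_inputs):
        # Collapse each image to one scalar just to produce logits of shape (batch, n_way)
        pooled = image_inputs.mean(dim=(1, 2, 3)).unsqueeze(1)
        return self.fc(pooled)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

texts = ["a photo of a cat", "stock market news"]  # Query-set text
encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

text_inputs = encoded["input_ids"].to(device)
attention_inputs = encoded["attention_mask"].to(device)
image_inputs = torch.randn(len(texts), 3, 224, 224, device=device)  # Placeholder images
labels = torch.tensor([0, 1], device=device)  # Toy 2-way episode

model = DummyMultimodalModel(n_way=2).to(device)
criterion = torch.nn.CrossEntropyLoss()
val_loss, accuracy, macro_accuracy, f1_score, class_f1 = test(
    image_inputs, text_inputs, attention_inputs, labels, model, criterion, device, n_way=2)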
Code Example #3
def test(model, test_loader, criterion, device, n_way):
    """
    Testing Loop

    :param model: The model that the data will be tested on
    :param test_loader: The data loader which will give the batched testing data
    :param criterion: The loss function
    :param device: The type of device that the testing will occur on
    :param n_way: The number of classes that the data can take
    :return: A tuple containing the validation loss, validation accuracy, average class accuracy,
    Macro-F1 score and a list of class F1 scores
    """
    # An F1 Score of 0 indicates that it is invalid
    model.eval()
    true_positive = list(0. for i in range(n_way))  # Number of correctly predicted samples per class
    total_truth = list(0. for i in range(n_way))  # Number of ground truths per class
    predicted_positive = list(0. for i in range(n_way))  # Number of predicted samples per class
    val_loss = 0
    correct_total = 0  # Total correctly predicted samples
    total = 0  # Total samples
    with torch.no_grad():
        for step, (data_inputs, data_labels) in enumerate(test_loader):
            inputs, labels = data_inputs.to(device), data_labels.to(device)
            pred = model(inputs)
            loss = criterion(pred, labels)
            val_loss += loss.item()  # Running validation loss
            _, predicted = torch.max(pred, 1)
            correct = (predicted == labels).squeeze()  # Samples that are correctly predicted
            correct_total += (predicted == labels).sum().item()
            total += labels.size(0)

            for i in range(len(predicted)):
                label = labels[i]
                true_positive[label] += correct[i].item()
                total_truth[label] += 1
                predicted_positive[predicted[i].item()] += 1  # True Positive + False Positive

    accuracy, macro_accuracy, f1_score, class_f1 = metrics(
        true_positive, total_truth, predicted_positive, correct_total, total)

    return val_loss / (step + 1), accuracy, macro_accuracy, f1_score, class_f1
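The DataLoader-based variant above can be exercised end to end with synthetic data. The random features and the tiny linear classifier below are placeholders chosen only to make the sketch self-contained; the call signature comes from the function definition.

import torch
from torch.utils.data import DataLoader, TensorDataset

n_way, n_samples, n_features = 3, 60, 16
features = torch.randn(n_samples, n_features)             # Placeholder inputs
targets = torch.randint(0, n_way, (n_samples,))           # Placeholder labels
test_loader = DataLoader(TensorDataset(features, targets), batch_size=16)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(n_features, n_way).to(device)     # Placeholder classifier
criterion = torch.nn.CrossEntropyLoss()

val_loss, accuracy, macro_accuracy, f1_score, class_f1 = test(
    model, test_loader, criterion, device, n_way)
print(f"loss {val_loss:.4f}, accuracy {accuracy:.4f}, macro-F1 {f1_score:.4f}")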
Code Example #4
def get_metrics(logits_q, querysz, y_qry, n_way):
    """
    Calculates the accuracy, average class accuracy, macro-F1 score and individual class F1 scores based on the inputs

    :param logits_q: The logits obtained from the query set
    :type logits_q: torch.Tensor
    :param querysz: The size of the query set
    :type querysz: int
    :param y_qry: The actual labels of the query set
    :type y_qry: torch.Tensor
    :param n_way: The number of classes for the episode
    :type n_way: int
    :return: A tuple of accuracy, average class accuracy, Macro-F1_score and a list of class F1 scores
    :rtype: tuple
    """
    true_positive = list(0. for i in range(n_way))  # Number of correctly predicted samples per class
    total_truth = list(0. for i in range(n_way))  # Number of ground truths per class
    predicted_positive = list(0. for i in range(n_way))  # Number of predicted samples per class
    correct_total = 0  # Total correctly predicted samples

    # Find variables for use in the metrics function
    _, predicted = torch.max(logits_q, 1)
    correct = (predicted == y_qry).squeeze()
    correct_total += (predicted == y_qry).sum().item()

    # Fill out the required parameters to be used in the metric function
    for i in range(len(predicted)):
        label = y_qry[i]
        true_positive[label] += correct[i].item()
        total_truth[label] += 1
        predicted_positive[predicted[i].item()] += 1  # True Positive + False Positive

    accuracy, macro_accuracy, f1_score, class_f1 = metrics(
        true_positive, total_truth, predicted_positive, correct_total, querysz)

    return accuracy, macro_accuracy, f1_score, class_f1
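get_metrics() operates directly on query-set logits, as they would appear inside a meta-learning evaluation loop. A small worked call under assumed shapes (a 5-way episode with 15 query samples per class, with random logits standing in for the adapted model's output):

import torch

n_way, k_query = 5, 15
querysz = n_way * k_query
y_qry = torch.arange(n_way).repeat_interleave(k_query)   # Query-set labels
logits_q = torch.randn(querysz, n_way)                   # Placeholder query logits

accuracy, macro_accuracy, f1_score, class_f1 = get_metrics(logits_q, querysz, y_qry, n_way)
print(f"episode accuracy {accuracy:.3f}, macro-F1 {f1_score:.3f}")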
Code Example #5
def main():
    if (
        (args["ld"] == "" and args["ed"] == "") or
        (args["ld"] != "" and args["ed"] != "")
    ):
        sys.exit("needs exactly one of --ld and --ed")


    dataset, labels = {}, {}
    if args["ld"] != "":
        print(f"attempting to load dataset from pickle file {args['ld']}")
        dataset, labels = pickle.load(open(args["ld"], "rb"))
    else:
        if not os.path.isdir(args["ed"]):
            sys.exit(f"{args['ed']} is not a directory")

        mon_dir = os.path.join(args["ed"], "monitored")
        if not os.path.isdir(mon_dir):
            sys.exit(f"{mon_dir} is not a directory")

        unm_dir = os.path.join(args["ed"], "unmonitored")
        if not os.path.isdir(unm_dir):
            sys.exit(f"{unm_dir} is not a directory")

        print(f"{now()} starting to load dataset from folder...")
        dataset, labels = shared.load_dataset(
            mon_dir,
            unm_dir,
            args["c"],
            args["p"],
            args["s"],
            args["l"],
            shared.trace2cells
        )
        if args["sd"] != "":
            pickle.dump((dataset, labels), open(args["sd"], "wb"))
            print(f"saved dataset to {args['sd']}")

    print(f"{now()} loaded {len(dataset)} items in dataset with {len(labels)} labels")

    split = shared.split_dataset(args["c"], args["p"], args["s"], args["f"], labels)
    print(
        f"{now()} split {len(split['train'])} training, "
        f"{len(split['validation'])} validation, and "
        f"{len(split['test'])} testing"
    )

    if args["z"] != "":
        dataset = shared.zero_dataset(dataset, args["z"])
        print(f"{now()} zeroed each item in dataset as data[{args['z']}]")

    model = DFNet(args["c"] + 1)  # one class for unmonitored
    if args["lm"] != "":
        model = torch.load(args["lm"])
        print(f"loaded model from {args['lm']}")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        print(f"{now()} using {torch.cuda.get_device_name(0)}")
        model.cuda()

    if args["train"]:
        # Note below that shuffle=True is *essential*, 
        # see https://stackoverflow.com/questions/54354465/
        train_gen = data.DataLoader(
            shared.Dataset(split["train"], dataset, labels),
            batch_size=args["batchsize"], shuffle=True,
        )
        validation_gen = data.DataLoader(
            shared.Dataset(split["validation"], dataset, labels),
            batch_size=args["batchsize"], shuffle=True,
        )

        optimizer = torch.optim.Adamax(params=model.parameters())
        criterion = torch.nn.CrossEntropyLoss()

        for epoch in range(args["epochs"]):
            print(f"{now()} epoch {epoch}")

            # training
            model.train()
            torch.set_grad_enabled(True)
            running_loss = 0.0
            n = 0
            for x, Y in train_gen:
                x, Y = x.to(device), Y.to(device)
                optimizer.zero_grad()
                outputs = model(x)
                loss = criterion(outputs, Y)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                n += 1
            print(f"\ttraining loss {running_loss/n}")

            # validation
            model.eval()
            torch.set_grad_enabled(False)
            running_corrects = 0
            n = 0
            for x, Y in validation_gen:
                x, Y = x.to(device), Y.to(device)

                outputs = model(x)
                _, preds = torch.max(outputs, 1)
                running_corrects += torch.sum(preds == Y)
                n += len(Y)
            print(f"\tvalidation accuracy {float(running_corrects)/float(n)}")

        if args["sm"] != "":
            torch.save(model, args["sm"])
            print(f"saved model to {args['sm']}")
    
    # testing
    testing_gen = data.DataLoader(
        shared.Dataset(split["test"], dataset, labels), 
        batch_size=args["batchsize"]
    )
    model.eval()
    torch.set_grad_enabled(False)
    predictions = []
    p_labels = []
    for x, Y in testing_gen:
        x = x.to(device)
        outputs = model(x)
        index = F.softmax(outputs, dim=1).data.cpu().numpy()
        predictions.extend(index.tolist())
        p_labels.extend(Y.data.numpy().tolist())

    print(f"{now()} made {len(predictions)} predictions with {len(p_labels)} labels")
    csvline = []
    threshold = np.append([0], 1.0 - 1 / np.logspace(0.05, 2, num=15, endpoint=True))
    threshold = np.around(threshold, decimals=4)
    for th in threshold:
        tp, fpp, fnp, tn, fn, accuracy, recall, precision, f1 = shared.metrics(
            th, predictions, p_labels, args["c"])
        print(
            f"\tthreshold {th:4.2}, "
            f"recall {recall:4.2}, "
            f"precision {precision:4.2}, "
            f"F1 {f1:4.2}, "
            f"accuracy {accuracy:4.2}   "
            f"[tp {tp:>5}, fpp {fpp:>5}, fnp {fnp:>5}, tn {tn:>5}, fn {fn:>5}]"
        )
        csvline.append([
            th, recall, precision, f1, tp, fpp, fnp, tn, fn, args["extra"]
        ])

    if args["csv"]:
        with open(args["csv"], "w", newline="") as csvfile:
            w = csv.writer(csvfile, delimiter=",")
            w.writerow(["th", "recall", "precision", "f1", "tp", "fpp", "fnp", "tn", "fn", "extra"])
            w.writerows(csvline)
        print(f"saved testing results to {args['csv']}")