Example #1
0
def evaluate(
    model: TEDD1104,
    data_loader: DataLoader,
    device: torch.device,
    fp16: bool,
) -> float:
    """
    Evaluate the accuracy of the model on the examples served by a data loader.

    Every batch must be a mapping with the keys "image1" ... "image5" and "y"
    (the gold labels). The five image tensors are stacked on a new dimension 1
    and dimensions 0 and 1 are then merged, so the five images that belong to
    the same example end up next to each other in the model input.

    Input:
     - model: TEDD1104 model to evaluate (it is switched to eval mode)
     - data_loader: torch.utils.data.DataLoader with the examples to evaluate
     - device: torch.device on which the model input is placed
     - fp16: bool, if True model.predict runs inside an autocast() context
    Output:
     - Accuracy: float, correct predictions / number of predictions.
       0.0 is returned when the data loader yields no predictions.
    """
    model.eval()
    correct = 0
    total = 0

    # Evaluation never back-propagates, so autograd does not need to track
    # the forward pass.
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating model"):
            frames = torch.stack(
                (
                    batch["image1"],
                    batch["image2"],
                    batch["image3"],
                    batch["image4"],
                    batch["image5"],
                ),
                dim=1,
            )
            # Merge the batch dimension and the new frame dimension.
            x = torch.flatten(frames, start_dim=0, end_dim=1).to(device)

            y = batch["y"]

            if fp16:
                with autocast():
                    predictions = model.predict(x)
            else:
                predictions = model.predict(x)

            # y is never moved to `device`, so compare on the CPU.
            # NOTE(review): assumes predictions and y have the same shape;
            # otherwise the comparison broadcasts - confirm against predict().
            predictions = predictions.cpu()

            # .item() keeps the counter a plain Python int.
            correct += (predictions == y).sum().item()
            total += len(predictions)

    if total == 0:
        # Nothing was evaluated: avoid a ZeroDivisionError.
        return 0.0

    return correct / total
def evaluate(
    model: TEDD1104,
    X: torch.Tensor,
    golds: torch.Tensor,
    device: torch.device,
    batch_size: int,
) -> float:
    """
    Evaluate the accuracy of the model on a set of input examples and their golds.

    The examples are consumed in batches produced by nn_batchs(X, golds, batch_size)
    (defined elsewhere; presumably it yields (inputs, golds) pairs with at most
    batch_size examples each - confirm against its definition).

    Input:
     - model: TEDD1104 model to evaluate (it is switched to eval mode)
     - X: input examples [num_examples, sequence_size, 3, H, W]
     - golds: golds for the input examples [num_examples]
     - device: torch.device on which each input batch is placed
     - batch_size: integer batch size
    Output:
     - Accuracy: float, correct predictions / len(golds).
       0.0 is returned when golds is empty.
    """
    model.eval()

    num_examples = len(golds)
    if num_examples == 0:
        # Nothing to evaluate: avoid a ZeroDivisionError.
        return 0.0

    correct = 0
    # Evaluation never back-propagates, so autograd does not need to track
    # the forward pass.
    with torch.no_grad():
        for X_batch, y_batch in nn_batchs(X, golds, batch_size):
            predictions: np.ndarray = (
                model.predict(X_batch.to(device)).cpu().numpy()
            )
            # Count the predictions of this batch that match their gold label.
            correct += np.sum(predictions == y_batch)

    return correct / num_examples