def discriminator_performance(x, y):
    """Count misclassifications of the module-level ``model`` on one batch.

    Runs ``model`` in eval mode on ``x`` and compares the arg-max prediction
    (over dimension 1 of the output) with ``y``.

    Args:
        x: input batch; passed through ``to_Var`` before the forward pass.
        y: target class indices for ``x``; passed through ``to_Var``.

    Returns:
        ``(faux_pos, faux_neg, total)`` as 0-d double tensors, where
        ``faux_pos`` counts wrong predictions of class 0, ``faux_neg`` counts
        wrong predictions of class 1, and ``total`` counts every wrong
        prediction.
    """
    import torch  # local import: the file's import block is not visible here

    # Evaluation only: no need to build the autograd graph.
    with torch.no_grad():
        y, y_pred = to_Var(y), model.eval()(to_Var(x))
        # Compute the predicted class once and reuse it for every count.
        predicted = y_pred.max(1)[1]
        wrong = predicted != y
        # `&` is the logical AND of two boolean masks.
        faux_pos = (wrong & (predicted == 0)).double().sum()
        faux_neg = (wrong & (predicted == 1)).double().sum()
        total = wrong.double().sum()
    return (faux_pos, faux_neg, total)
def big_loss(images, labels, batch_size=100):
    """Return the average of ``loss_fn`` over a whole dataset.

    The dataset is evaluated in mini-batches with the module-level ``model``
    in eval mode. Each batch's loss is weighted by the batch size, so the
    result is the per-sample mean provided ``loss_fn`` averages over its batch
    (assumption: mean reduction — confirm against how ``loss_fn`` is built).

    Args:
        images: input tensor; the first dimension indexes samples.
        labels: target tensor aligned with ``images``.
        batch_size: evaluation batch size (defaults to 100).

    Returns:
        The size-weighted average loss as a Python float.

    Raises:
        ZeroDivisionError: if ``images`` is empty.
    """
    import torch  # local import: the file's import block is not visible here

    data = TensorDataset(images, labels)
    loader = DataLoader(data, batch_size=batch_size, shuffle=False)
    model.eval()  # switch to eval mode once, outside the batch loop
    count = 0
    # Evaluation only: no need to build the autograd graph.
    with torch.no_grad():
        for (x, y) in loader:
            y, y_pred = to_Var(y), model(to_Var(x))
            count += len(x) * loss_fn(y_pred, y).item()
    return count / len(images)
def accuracy(images, labels, k=1, batch_size=10):
    """Return the top-``k`` accuracy of the module-level ``model``, in percent.

    A sample counts as correct when its label is among the indices of the
    ``k`` largest entries of the model output along dimension 1.

    Args:
        images: input tensor; the first dimension indexes samples.
        labels: class-index tensor aligned with ``images``.
        k: number of top-scoring classes accepted (1 = plain accuracy).
        batch_size: evaluation batch size (defaults to 10).

    Returns:
        ``100 * correct / len(images)`` as a 0-d float tensor.

    Raises:
        ZeroDivisionError: if ``images`` is empty.
    """
    import torch  # local import: the file's import block is not visible here

    data = TensorDataset(images, labels)
    loader = DataLoader(data, batch_size=batch_size, shuffle=False)
    model.eval()  # switch to eval mode once, outside the batch loop
    count = 0
    # Evaluation only: no need to build the autograd graph.
    with torch.no_grad():
        for (x, y) in loader:
            y, y_pred = to_Var(y), model(to_Var(x))
            # Indices of the k largest scores per sample; assumes y_pred is
            # (batch, classes), giving (batch, k) here.
            y_pred_k = y_pred.topk(k, 1, True, True)[1]
            # Transpose so each of the k rows broadcasts against y.
            # Cast to float before summing: uint8 (ByteTensor) sums wrap at
            # 256 in older PyTorch versions.
            count += (y_pred_k.t() == y).float().sum()
    return 100 * count / len(images)


# Training script: trains `model` for `epochs` epochs and, after each epoch,
# evaluates top-`k` accuracy and loss on the training set and top-`k`
# accuracy on the validation set.
# NOTE(review): this fragment is truncated — the `try` below has no visible
# `except`/`finally`, and nothing is appended to `val_accs`/`val_losses`
# within this view. `epochs`, `k`, `train_loader`, `optimizer`, `bar`, and the
# train_/val_ image and label tensors are defined elsewhere.

# Per-epoch metric histories.
train_accs, val_accs = [], []
train_losses, val_losses = [], []

try:
    # Main loop over each epoch
    for e in range(epochs):

        # Secondary loop over each mini-batch; `bar(train_loader, e)` is
        # defined elsewhere and presumably wraps the loader in a progress bar.
        for (x, y) in bar(train_loader, e):

            # Forward pass with the model in training mode
            y_pred = model.train()(to_Var(x))
            loss = loss_fn(y_pred, to_Var(y))

            # Optimizer step: clear accumulated gradients, backpropagate,
            # then update the parameters
            model.zero_grad()
            loss.backward()
            optimizer.step()

        # Top-k accuracy and loss on the training set after this epoch.
        train_acc = accuracy(train_images, train_labels, k)
        train_loss = big_loss(train_images, train_labels)
        train_accs.append(train_acc)
        train_losses.append(train_loss)

        # Top-k accuracy on the validation set after this epoch.
        val_acc = accuracy(val_images, val_labels, k)
Beispiel #5
0
def load_image(img_id):
    """Fetch sample ``img_id`` from the module-level ``images`` as a one-item batch.

    The sample is reshaped with ``view(1, 1, 28, 28)`` — presumably
    (batch, channel, height, width), which requires 784 elements — and then
    wrapped with ``to_Var``.
    """
    picture = images[img_id]
    as_batch = picture.view(1, 1, 28, 28)
    return to_Var(as_batch)