Example #1
# Imports assumed for this example; CapsNet and preprocess come from the
# surrounding project (their module paths below are a guess).
from pprint import pprint

import numpy as np
from keras import callbacks
from sklearn.datasets import fetch_lfw_people

from capsnet import CapsNet           # project-local model class (path assumed)
from preprocessing import preprocess  # project-local data helper (path assumed)


def main():
    """CapsNet run as module.

    Run full cycle when CapsNet is run as a module.
    """
    people = fetch_lfw_people(
        color=True,
        min_faces_per_person=25,
        # resize=1.,
        # slice_=(slice(48, 202), slice(48, 202))
    )

    data = preprocess(people)

    (x_train, y_train), (x_test, y_test) = data  # noqa

    model = CapsNet(x_train.shape[1:], len(np.unique(y_train, axis=0)))

    model.summary()

    # Start TensorBoard (batch_size/write_grads follow the older standalone
    # Keras callback API and are not available in recent tf.keras releases)
    tensorboard = callbacks.TensorBoard('model/tensorboard_logs',
                                        batch_size=10,
                                        histogram_freq=1,
                                        write_graph=True,
                                        write_grads=True,
                                        write_images=True)
    model.train(data, batch_size=10, extra_callbacks=[tensorboard])
    model.save('/tmp')

    metrics = model.test(x_test, y_test)
    pprint(metrics)
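
The preprocess helper above is defined in the surrounding project and not shown on this page. Below is a minimal sketch of what such a helper could look like, assuming it rescales the LFW images, one-hot encodes the targets, and returns a train/test split; the implementation is a guess, not the project's actual code.

import numpy as np
from sklearn.model_selection import train_test_split


def preprocess(people, test_size=0.2):
    """Hypothetical preprocessing for fetch_lfw_people output."""
    x = people.images.astype('float32') / 255.0   # scale pixel values to [0, 1]
    n_classes = len(np.unique(people.target))
    y = np.eye(n_classes)[people.target]          # one-hot encode the person ids
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=test_size, stratify=people.target)
    return (x_train, y_train), (x_test, y_test)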
Example #2
import torch
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from capsnet import CapsNet
from functions import DigitMarginLoss
from functions import accuracy

# batch_size=1 trains on one example at a time, which is slow but matches the
# per-sample loss computation below.
train_loader = DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       # transforms.RandomShift(2),  # custom transform, not part of torchvision
                       transforms.ToTensor()])),
    batch_size=1, shuffle=True)

test_loader = DataLoader(
    datasets.MNIST('data', train=False,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=1)

model = CapsNet()
optimizer = optim.Adam(model.parameters())
margin_loss = DigitMarginLoss()
reconstruction_loss = torch.nn.MSELoss(size_average=False)
model.train()

for epoch in range(1, 11):
    epoch_tot_loss = 0
    epoch_tot_acc = 0
    for batch, (data, target) in enumerate(train_loader, 1):
        data = Variable(data)
        target = Variable(target)

        digit_caps, reconstruction = model(data, target)
        # Total loss = margin loss + down-weighted reconstruction (MSE) loss.
        loss = (margin_loss(digit_caps, target)
                + 0.0005 * reconstruction_loss(reconstruction, data.view(-1)))
        # Accumulate only the scalar value so each batch's graph can be freed;
        # backward() then no longer needs retain_graph=True.
        epoch_tot_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
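
DigitMarginLoss and accuracy come from the project's functions module and are not shown here (the snippet also ends before accuracy is used). Below is a minimal sketch of the margin loss from the CapsNet paper (Sabour et al., 2017), assuming digit_caps holds one capsule vector per class for a single example and target is an integer class label; the class name matches the import above, but the body is a guess.

import torch


class DigitMarginLoss(torch.nn.Module):
    """Margin loss from 'Dynamic Routing Between Capsules' (sketch)."""

    def __init__(self, m_pos=0.9, m_neg=0.1, lam=0.5):
        super().__init__()
        self.m_pos, self.m_neg, self.lam = m_pos, m_neg, lam

    def forward(self, digit_caps, target):
        # digit_caps: (num_classes, dim) capsule outputs for one example;
        # the capsule length encodes the probability that the class is present.
        lengths = digit_caps.norm(dim=-1)
        t = torch.zeros_like(lengths)
        t[target] = 1.0  # one-hot indicator of the true class
        loss = (t * torch.clamp(self.m_pos - lengths, min=0) ** 2
                + self.lam * (1 - t) * torch.clamp(lengths - self.m_neg, min=0) ** 2)
        return loss.sum()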
Example #3
# Imports assumed for this example; CapsNet and CapsuleLoss come from the
# surrounding project (the module path below is a guess).
import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from capsnet import CapsNet, CapsuleLoss  # project-local definitions (path assumed)

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main():
    # Load model
    model = CapsNet().to(device)
    criterion = CapsuleLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)

    # Load data
    transform = transforms.Compose([
        # shift by 2 pixels in either direction with zero padding.
        transforms.RandomCrop(28, padding=2),
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    DATA_PATH = "./data"
    BATCH_SIZE = 128
    train_loader = DataLoader(
        dataset=MNIST(root=DATA_PATH,
                      download=True,
                      train=True,
                      transform=transform),
        batch_size=BATCH_SIZE,
        num_workers=4,
        shuffle=True,
    )
    test_loader = DataLoader(
        dataset=MNIST(root=DATA_PATH,
                      download=True,
                      train=False,
                      transform=transform),
        batch_size=BATCH_SIZE,
        num_workers=4,
        shuffle=True,
    )

    # Train
    EPOCHES = 50
    model.train()
    for ep in range(EPOCHES):
        batch_id = 1
        correct, total, total_loss = 0, 0, 0.0
        for images, labels in train_loader:
            optimizer.zero_grad()
            images = images.to(device)
            labels = torch.eye(10).index_select(dim=0, index=labels).to(device)
            logits, reconstruction = model(images)

            # Compute loss & accuracy
            loss = criterion(images, labels, logits, reconstruction)
            correct += torch.sum(
                torch.argmax(logits, dim=1) == torch.argmax(labels,
                                                            dim=1)).item()
            total += len(labels)
            accuracy = correct / total
            # Accumulate only the scalar value so each batch's graph can be freed.
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            print("Epoch {}, batch {}, loss: {}, accuracy: {}".format(
                ep + 1, batch_id, total_loss / batch_id, accuracy))
            batch_id += 1
        # ExponentialLR.step() takes no epoch argument; call it once per epoch.
        scheduler.step()
        print("Total loss for epoch {}: {}".format(ep + 1, total_loss))

    # Eval
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():  # no gradients needed during evaluation
        for images, labels in test_loader:
            images = images.to(device)
            # Categorical (one-hot) encoding of the labels
            labels = torch.eye(10).index_select(dim=0, index=labels).to(device)
            logits, reconstructions = model(images)
            pred_labels = torch.argmax(logits, dim=1)
            correct += torch.sum(pred_labels == torch.argmax(labels, dim=1)).item()
            total += len(labels)
    print("Accuracy: {}".format(correct / total))

    # Save model
    torch.save(
        model.state_dict(),
        "./model/capsnet_ep{}_acc{}.pt".format(EPOCHES, correct / total),
    )
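
To reuse the checkpoint that main() writes, the state dict can be loaded back into a fresh CapsNet instance. A minimal sketch, assuming the same CapsNet class and device as above; the file name is only an example and must match whatever main() actually saved.

import torch

# Rebuild the architecture, then restore the trained weights.
model = CapsNet().to(device)
# Path is illustrative; use the file written by main() above.
state_dict = torch.load("./model/capsnet_ep50_acc0.99.pt", map_location=device)
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode before running predictions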