# Example #1
def main(epochs=100, lr=0.0001, save_path="./model.pth"):
    """Train a DCTI network on CIFAR-10, report per-epoch results, and save the weights.

    After every training epoch the network is run over the test loader and the
    epoch number, the value returned by ``train`` and the accuracy against the
    CIFAR-10 test labels are printed. The final ``state_dict`` is written with
    ``torch.save`` once training finishes.

    Args:
        epochs: Number of training epochs (epochs are numbered from 1).
        lr: Learning rate passed to the Adam optimizer.
        save_path: Destination file for the trained model's ``state_dict``.
    """
    train_loader, test_loader = CIFAR10.pytorch_loader()
    # Only the fourth value returned by CIFAR10.numpy() (the test labels) is
    # needed; fetch it once instead of on every epoch.
    _, _, _, y_test = CIFAR10.numpy()

    net = DCTI().to(device)
    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1, epochs + 1):
        loss = train(net, train_loader, optimizer, criterion)
        # NOTE(review): assumes test() returns a tensor of predictions aligned
        # element-wise with y_test -- confirm against test()'s definition.
        y_prediction = test(net, test_loader).detach().cpu().numpy()
        print(
            f"Train epoch: {epoch:>3}\t Loss: {loss:.4f}\t Accuracy: {accuracy(y_test, y_prediction):.2f}"
        )

    torch.save(net.state_dict(), save_path)
def test_cifar10():
    """Smoke-test the CIFAR10 dataset accessors.

    Checks that the NumPy and PyTorch accessors load without raising and
    return the number of values ``main()`` unpacks from them (four and two,
    respectively), and that the TensorFlow loader raises NotImplementedError.
    """
    # Unpack instead of discarding the results so the test fails if the
    # arity that callers depend on ever changes.
    _, _, _, _ = CIFAR10.numpy()
    _, _ = CIFAR10.pytorch_loader()

    with pytest.raises(NotImplementedError):
        CIFAR10.tensorflow_loader()