def main(*, epochs=100, lr=0.0001, save_path="./model.pth"):
    """Train a DCTI model on CIFAR-10 and save its weights.

    Uses Adam with cross-entropy loss. After each epoch, prints the training
    loss and the accuracy on the test split. Once training finishes, writes
    the model's ``state_dict`` to ``save_path``.

    Args:
        epochs: Number of training epochs (default 100, as originally
            hard-coded).
        lr: Adam learning rate (default 0.0001, as originally hard-coded).
        save_path: Destination of the saved ``state_dict``
            (default "./model.pth").
    """
    train_loader, test_loader = CIFAR10.pytorch_loader()
    # Only the test labels are needed from the NumPy view; they are the
    # reference for the accuracy metric.
    # NOTE(review): assumes test_loader yields samples in the same order as
    # y_test (i.e. it is not shuffled) — confirm in CIFAR10.pytorch_loader.
    _, _, _, y_test = CIFAR10.numpy()

    net = DCTI().to(device)
    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1, epochs + 1):
        loss = train(net, train_loader, optimizer, criterion)
        # Move predictions off the autograd graph / device so they can be
        # compared with the NumPy labels.
        y_prediction = test(net, test_loader).detach().cpu().numpy()
        print(
            f"Train epoch: {epoch:>3}\t Loss: {loss:.4f}\t Accuracy: {accuracy(y_test, y_prediction):.2f}"
        )

    torch.save(net.state_dict(), save_path)
def test_cifar10():
    """Smoke-test the CIFAR10 dataset accessors.

    The NumPy and PyTorch accessors only need to complete without raising.
    The TensorFlow accessor is expected to raise NotImplementedError.
    """
    # Supported backends: any exception here fails the test.
    CIFAR10.numpy()
    CIFAR10.pytorch_loader()

    # Unsupported backend: the call must raise NotImplementedError.
    pytest.raises(NotImplementedError, CIFAR10.tensorflow_loader)