"""Train a three-layer ConvNet on CIFAR-10 and plot its learning curves.

Loads the dataset with ``data_utils.get_CIFAR10_data``, trains
``CNN.ThreeLayerConvNet`` through ``Solver`` using the ``'sgd_momentum'``
update rule, then draws two stacked panels: the training-loss history and the
train/validation accuracy histories.
"""
import matplotlib.pyplot as plt
from cnn import data_utils
from cnn.solver import Solver
from cnn import CNN
import numpy as np  # NOTE(review): not used by this script; kept in case other code relies on it

data = data_utils.get_CIFAR10_data()

model = CNN.ThreeLayerConvNet(reg=0.9)

solver = Solver(
    model,
    data,
    lr_decay=0.95,
    print_every=10,
    num_epochs=5,
    batch_size=8,
    update_rule='sgd_momentum',
    optim_config={
        'learning_rate': 5e-4,
        'momentum': 0.9,
    },
)
solver.train()

# Top panel: training loss, plotted against its index in loss_history.
plt.subplot(2, 1, 1)
plt.title('Training loss')
plt.plot(solver.loss_history, 'o')
plt.xlabel('Iteration')

# Bottom panel: train vs. validation accuracy.
plt.subplot(2, 1, 2)
plt.title('Accuracy')
plt.plot(solver.train_acc_history, '-o', label='train')
plt.plot(solver.val_acc_history, '-o', label='val')
# NOTE(review): assumes Solver appends to the accuracy histories once per
# epoch (as the CS231n-style Solver does) -- confirm against cnn/solver.py.
plt.xlabel('Epoch')
# The curves above carry labels; without legend() those labels are never shown.
plt.legend(loc='lower right')

# Keep the top panel's x-label from overlapping the bottom panel's title.
plt.tight_layout()
# Without show(), running this as a script trains the model and exits
# without ever displaying the figure.
plt.show()