Example #1
0
import matplotlib.pyplot as plt

from cnn import data_utils
from cnn.solver import Solver
from cnn import CNN
import numpy as np
# Fetch the CIFAR-10 dataset via the project's data utilities; the returned
# object is passed to Solver unchanged (its structure is defined in cnn.data_utils).
data = data_utils.get_CIFAR10_data()

# Three-layer convolutional network, constructed with reg=0.9.
model = CNN.ThreeLayerConvNet(reg=0.9)

# Hyperparameters forwarded to the 'sgd_momentum' update rule.
optimizer_settings = {
    'learning_rate': 5e-4,
    'momentum': 0.9,
}

# NOTE(review): how lr_decay and print_every are applied (per epoch / per
# iteration) is decided inside cnn.solver.Solver; confirm there.
solver = Solver(
    model,
    data,
    update_rule='sgd_momentum',
    optim_config=optimizer_settings,
    lr_decay=0.95,
    num_epochs=5,
    batch_size=8,
    print_every=10,
)
solver.train()

# Upper panel: the loss history recorded by the solver, drawn as markers only.
plt.subplot(2, 1, 1)
plt.title('Training loss')
plt.plot(solver.loss_history, 'o')
plt.xlabel('Iteration')

# Lower panel: training and validation accuracy histories as marker-lines.
# NOTE(review): each series gets a label, but no plt.legend()/plt.show() is
# visible in this span; confirm they appear later in the file.
plt.subplot(2, 1, 2)
plt.title('Accuracy')
for _history, _label in ((solver.train_acc_history, 'train'),
                         (solver.val_acc_history, 'val')):
    plt.plot(_history, '-o', label=_label)