def plot_momentum(momenta=(0.0, 0.3, 0.6, 0.9)):
    """Sweep momentum values and plot the resulting test accuracies.

    For each momentum, ``test_mnist_one_hot`` is called with that momentum and
    a per-run ``csv_filename`` of the form ``'momentum_<momentum*100>'``. Its
    return values (presumably per-epoch test accuracies — confirm against
    ``test_mnist_one_hot``) are collected and passed to ``plot_from_list``
    together with one label per run and the name ``'momentum'``.

    Args:
        momenta: Momentum values to try, in order. Defaults to the original
            hard-coded sweep ``(0.0, 0.3, 0.6, 0.9)``.
    """
    accuracy_ranges = []
    labels = []
    for momentum in momenta:
        # Round instead of truncating: int() would turn float error into an
        # off-by-one suffix, e.g. int(0.29 * 100) == 28.
        percent = int(round(momentum * 100))
        test_accuracies = test_mnist_one_hot(
            momentum=momentum,
            csv_filename='momentum_{}'.format(percent))
        accuracy_ranges.append(test_accuracies)
        labels.append('Momentum: {}'.format(momentum))
    plot_from_list(accuracy_ranges, labels, 'momentum')
def plot_batch_size(batch_sizes=(1, 10, 100)):
    """Sweep mini-batch sizes and plot the resulting test accuracies.

    For each batch size, ``test_mnist_one_hot`` is called with that
    ``batch_size`` and a per-run ``csv_filename`` of the form
    ``'batch_size_<batch_size>'``. Its return values (presumably per-epoch
    test accuracies — confirm against ``test_mnist_one_hot``) are collected
    and passed to ``plot_from_list`` together with one label per run and the
    name ``'batch_size'``.

    Args:
        batch_sizes: Batch sizes to try, in order. Defaults to the original
            hard-coded sweep ``(1, 10, 100)``.
    """
    accuracy_ranges = []
    labels = []
    for batch_size in batch_sizes:
        test_accuracies = test_mnist_one_hot(
            batch_size=batch_size,
            csv_filename='batch_size_{}'.format(batch_size))
        accuracy_ranges.append(test_accuracies)
        labels.append('Batch size: {}'.format(batch_size))
    plot_from_list(accuracy_ranges, labels, 'batch_size')
def plot_learning_rate(learning_rates=(0.01, 0.02, 0.05, 0.1)):
    """Sweep learning rates and plot the resulting test accuracies.

    For each learning rate, ``test_mnist_one_hot`` is called with that
    ``learning_rate`` and a per-run ``csv_filename`` of the form
    ``'learning_rate_<learning_rate*100>'``. Its return values (presumably
    per-epoch test accuracies — confirm against ``test_mnist_one_hot``) are
    collected and passed to ``plot_from_list`` together with one label per
    run and the name ``'learning_rate'``.

    Args:
        learning_rates: Learning rates to try, in order. Defaults to the
            original hard-coded sweep ``(0.01, 0.02, 0.05, 0.1)``.
    """
    accuracy_ranges = []
    labels = []
    for learning_rate in learning_rates:
        # Round instead of truncating: int() would turn float error into an
        # off-by-one suffix, e.g. int(0.29 * 100) == 28.
        percent = int(round(learning_rate * 100))
        test_accuracies = test_mnist_one_hot(
            learning_rate=learning_rate,
            csv_filename='learning_rate_{}'.format(percent))
        accuracy_ranges.append(test_accuracies)
        labels.append('Learning rate: {}'.format(learning_rate))
    plot_from_list(accuracy_ranges, labels, 'learning_rate')
def plot_layer_decay(layer_decays=(0.7, 0.8, 0.9, 0.99, 1)):
    """Sweep layer-decay values and plot the resulting test accuracies.

    For each value, ``test_mnist_one_hot`` is called with that ``layer_decay``
    and a per-run ``csv_filename`` of the form
    ``'layer_decay_<layer_decay*100>'``. Its return values (presumably
    per-epoch test accuracies — confirm against ``test_mnist_one_hot``) are
    collected and passed to ``plot_from_list`` together with one label per
    run and the name ``'layer_decay'``.

    Args:
        layer_decays: Layer-decay values to try, in order. Defaults to the
            original hard-coded sweep ``(0.7, 0.8, 0.9, 0.99, 1)``.
    """
    accuracy_ranges = []
    labels = []
    for layer_decay in layer_decays:
        # Round instead of truncating: int() would turn float error into an
        # off-by-one suffix, e.g. int(0.57 * 100) == 56.
        percent = int(round(layer_decay * 100))
        test_accuracies = test_mnist_one_hot(
            layer_decay=layer_decay,
            csv_filename='layer_decay_{}'.format(percent))
        accuracy_ranges.append(test_accuracies)
        labels.append('Layer decay: {}'.format(layer_decay))
    plot_from_list(accuracy_ranges, labels, 'layer_decay')
def plot_network_size(lst_hidden_layers=((100,), (200,), (300,),
                                         (100, 100), (200, 100), (300, 100))):
    """Sweep hidden-layer configurations and plot the resulting accuracies.

    For each configuration, ``test_mnist_one_hot`` is called with that
    ``hidden_layers`` value and a per-run ``csv_filename`` of the form
    ``'network_size_<size1>_<size2>...'``. Its return values (presumably
    per-epoch test accuracies — confirm against ``test_mnist_one_hot``) are
    collected and passed to ``plot_from_list`` together with one label per
    run and the name ``'network_size'``.

    Args:
        lst_hidden_layers: Sequence of hidden-layer size tuples to try, in
            order. Defaults to the original hard-coded six configurations.
    """
    accuracy_ranges = []
    labels = []
    for hidden_layers in lst_hidden_layers:
        # Encode every layer size in the filename, e.g. (200, 100) -> "200_100".
        size_suffix = '_'.join(str(layer) for layer in hidden_layers)
        test_accuracies = test_mnist_one_hot(
            hidden_layers=hidden_layers,
            csv_filename='network_size_{}'.format(size_suffix))
        accuracy_ranges.append(test_accuracies)
        labels.append('Hidden layers: {}'.format(hidden_layers))
    plot_from_list(accuracy_ranges, labels, 'network_size')
def plot_logistic_vs_tanh():
    """Run the MNIST test once per activation and plot both curves together.

    ``test_mnist_one_hot`` is called first with ``sigmoid='logistic'`` and
    then with ``sigmoid='tanh'``. Each run uses its activation name as the
    ``csv_filename``. The two results are passed to ``plot_from_list`` as a
    tuple, labelled ``('logistic', 'tanh')``, under the name
    ``'logistic_vs_tanh'``.
    """
    # Per-activation extra kwargs. Only the logistic run overrides the learning
    # rate (0.1178); why that particular value was chosen is not visible here.
    # The tanh run uses test_mnist_one_hot's own default.
    runs = (
        ('logistic', {'learning_rate': 0.1178}),
        ('tanh', {}),
    )
    results = []
    for activation, overrides in runs:
        results.append(test_mnist_one_hot(sigmoid=activation,
                                          csv_filename=activation,
                                          **overrides))
    run_names = tuple(activation for activation, _ in runs)
    plot_from_list(tuple(results), run_names, 'logistic_vs_tanh')