def run_experiment(average_gradients, batch_size, iterations, verbose):
    """Train a fresh ``ConvNet`` on MNIST and return its best accuracy.

    The default TensorFlow graph is reset, a new network is built and
    ``iterations`` training steps are run inside a new session. The network
    is evaluated on the MNIST test split after every 100th step, and once
    more after the last step if that step was not already evaluated.

    Args:
        average_gradients: Forwarded unchanged to ``net.setup_train``.
            NOTE(review): presumably the number of steps over which
            gradients are averaged -- confirm against ``ConvNet``.
        batch_size: Number of training examples requested per step.
        iterations: Total number of training steps to run.
        verbose: When true, print loss and accuracy at each periodic
            evaluation.

    Returns:
        The highest accuracy value returned by ``net.evaluate`` during
        the run.
    """
    eval_interval = 100

    tf.reset_default_graph()
    net = ConvNet()

    # The "validation" data used here is the MNIST *test* split.
    validation_batch = mnist.test.images
    val_count = validation_batch.shape[0]
    validation_batch = np.reshape(validation_batch, (val_count, 28, 28, 1))
    validation_labels = mnist.test.labels

    net.setup_train(average_gradients=average_gradients)

    # List of (accuracy, step) pairs, in the order they were measured.
    training_log = []
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for step in range(1, iterations + 1):
            batch = mnist.train.next_batch(batch_size)
            input_batch = np.reshape(batch[0], (batch_size, 28, 28, 1))
            loss = net.train(sess, input_batch, batch[1])

            if step % eval_interval == 0:
                accuracy = net.evaluate(
                    sess, validation_batch, validation_labels)
                training_log.append((accuracy, step))
                if verbose:
                    print('[{:d}/{:d}] loss: {:.3g}, accuracy: {:.3g}%'.format(
                        step, iterations, loss, accuracy))

        # When `iterations` is a multiple of the evaluation interval, the
        # last loop step has already evaluated the final weights; repeating
        # the full test-set pass would only duplicate that entry.
        # NOTE(review): assumes net.evaluate is deterministic for fixed
        # weights -- confirm.
        if not training_log or training_log[-1][1] != iterations:
            accuracy = net.evaluate(sess, validation_batch, validation_labels)
            training_log.append((accuracy, iterations))

    # max() returns the first entry with the highest accuracy, which is the
    # same entry a stable descending sort would place first.
    best = max(training_log, key=lambda entry: entry[0])
    print('Training finished. Best accuracy: {:.3g} at iteration {:d}.'.
          format(best[0], best[1]))
    return best[0]
def experiment(threshold, iterations, train_loss, n_conv, optimizer,
               batch_size=1, batch_norm=False, learning_rate=1e-3,
               summary_dir=None):
    """Train a ``ConvNet`` on data from ``get_data`` and collect its metrics.

    A model with 4 filters is built from the given settings, the settings
    are printed, and ``iterations`` training steps are run. On every 250th
    step (starting with step 0) the model's accuracy is measured on a
    validation batch that is fetched once up front with
    ``get_data(threshold, 100, verbose=True)``.

    Args:
        threshold: Forwarded to ``get_data`` for both validation and
            training batches.
        iterations: Number of training steps to run.
        train_loss: Forwarded to ``ConvNet``; its ``.value`` is printed.
        n_conv: Forwarded to ``ConvNet``.
        optimizer: Forwarded to ``ConvNet``; its ``.value`` is printed.
        batch_size: Size requested from ``get_data`` for each training step.
        batch_norm: Forwarded to ``ConvNet``.
        learning_rate: Forwarded to ``ConvNet``.
        summary_dir: Forwarded to ``ConvNet``.

    Returns:
        A dict holding, for every key in the mapping returned by
        ``model.train``, the list of values seen across all steps, plus an
        ``'accuracy'`` entry listing ``(step, accuracy)`` tuples for each
        validation pass. The dict is empty when ``iterations`` is 0.
    """
    model = ConvNet(
        filters=4,
        n_conv=n_conv,
        train_loss=train_loss,
        batch_norm=batch_norm,
        optimizer=optimizer,
        learning_rate=learning_rate,
        summary_dir=summary_dir,
    )
    print('train_loss:', train_loss.value,
          'optimizer:', optimizer.value,
          'n_conv:', n_conv,
          'batch_norm:', batch_norm,
          'batch_size:', batch_size,
          'learning_rate:', learning_rate)

    val_inputs, val_targets = get_data(threshold, 100, verbose=True)

    history = {}
    top_accuracy, top_step = 0.0, 0
    for step in tqdm(range(iterations)):
        inputs, targets = get_data(threshold, batch_size)
        metrics = model.train(inputs, targets)

        # The metric names are only known once the first step has run, so
        # the per-metric lists are created lazily here.
        if step == 0:
            history.update((name, []) for name in metrics.keys())
            history['accuracy'] = []

        for name, value in metrics.items():
            history[name].append(value)

        # Validation only happens on every 250th step.
        if step % 250 != 0:
            continue

        accuracy = model.accuracy(val_inputs, val_targets)
        if accuracy > top_accuracy:
            top_accuracy, top_step = accuracy, step
        history['accuracy'].append((step, accuracy))

    print('Best accuracy %.3g at iteration %d.' % (top_accuracy, top_step))
    return history