def test_custom_axes_placement(hang=False):
    """
    Exercise dbplot with axes placed explicitly through two GridSpecs.

    The left half of the figure (left=0, right=0.5) gets three stacked line
    plots named 'a', 'b' and 'c', each a sine wave shifted by 0, 1 and 2
    respectively.  The right half (left=0.5, right=1) gets two random images
    named 'im1' (20x20) and 'im2' (20x20x3).

    :param hang: If True, call dbplot_hang() once everything is plotted.
    """
    line_grid = gridspec.GridSpec(3, 1, left=0, right=0.5, hspace=0)
    for row, plot_name in enumerate(('a', 'b', 'c')):
        # The row index doubles as the phase offset (0, 1, 2) of the sine wave.
        wave = np.sin(np.linspace(0, 10, 100) + row)
        dbplot(wave, plot_name, plot_type='line', axis=line_grid[row, 0])

    image_grid = gridspec.GridSpec(2, 1, left=0.5, right=1, hspace=0.1)
    dbplot(np.random.randn(20, 20), 'im1', axis=image_grid[0, 0])
    dbplot(np.random.randn(20, 20, 3), 'im2', axis=image_grid[1, 0])

    if hang:
        dbplot_hang()
def demo_converge_to_pareto_curve(layer_sizes=[100, 100, 100, 100],
                                  w_scales=[1, 1, 1],
                                  n_samples=100,
                                  learning_rate=0.01,
                                  n_epochs=100,
                                  minibatch_size=10,
                                  n_random_points_to_try=1000,
                                  random_scale_range=(1, 5),
                                  parametrization='log',
                                  computation_weights=np.logspace(-6, -3, 8),
                                  layerwise_scales=True,
                                  show_random_scales=True,
                                  error_loss='L1',
                                  hang_now=True,
                                  seed=1234):
    """
    Plot the computation/error tradeoff of randomly drawn scales against the
    trajectory followed by optimized scales.

    A network is initialized and its unquantized output on random data is taken
    as the reference.  First, a set of random scale vectors is evaluated with
    'round' quantization and plotted in the 'Tradeoff' axis.  Then, for each
    value in ``computation_weights``, a CompErrorScaleOptimizer is trained and
    its current (kOps/sample, relative error) point is plotted at every test
    step.

    :param layer_sizes: Layer sizes passed to initialize_network_params;
        layer_sizes[0] is also the dimensionality of the random input data.
    :param w_scales: One multiplier per weight matrix, applied to the
        initialized weights (paired with them through izip_equal).
    :param n_samples: Number of random input samples to generate.
    :param learning_rate: Learning rate given to GradientDescent.
    :param n_epochs: Number of epochs passed to minibatch_iterate_info; a test
        step is requested at epochs 0, 1, ..., n_epochs-1.
    :param minibatch_size: Minibatch size passed to minibatch_iterate_info.
    :param n_random_points_to_try: Number of random scale vectors to evaluate.
    :param random_scale_range: Pair (low, high).  Random scales are the absolute
        value of normal samples with mean ``mean(range)`` and standard
        deviation ``high - low``.
    :param parametrization: Passed through to CompErrorScaleOptimizer.
    :param computation_weights: Values of ``comp_weight``; one optimization run
        is performed for each.
    :param layerwise_scales: Passed through to CompErrorScaleOptimizer.
    :param show_random_scales: If True, plot the random scales and the
        optimized scales in the 'Scales' axis.
    :param error_loss: Passed to ``train_scales`` as ``error_loss``.
    :param hang_now: If True, call dbplot_hang() when all runs are done.
    :param seed: Seed of the RandomState used for weight initialization, data
        generation, the random scales and the optimizer's rng.
    """
    set_dbplot_default_layout('h')

    rng = np.random.RandomState(seed)
    ws = initialize_network_params(layer_sizes=layer_sizes, mag='xavier-relu', include_biases=False, rng=rng)
    ws = [w * s for w, s in izip_equal(ws, w_scales)]
    train_data = rng.randn(n_samples, layer_sizes[0])

    # Reference output: forward pass with no scales and no quantization.
    # NOTE(review): the forward-pass helpers below receive the literal 1234 rather than
    # `seed` - confirm that this is intentional.
    _, true_out = quantized_forward_pass_cost_and_output(
        train_data, weights=ws, scales=None, quantization_method=None, seed=1234)

    # Run the random search
    scales_to_try = np.abs(rng.normal(
        loc=np.mean(random_scale_range),
        scale=np.diff(random_scale_range),
        size=(n_random_points_to_try, len(ws))))

    if show_random_scales:
        ax = dbplot(scales_to_try.T, 'random_scales', axis='Scales',
                    plot_type=lambda: LinePlot(plot_kwargs=dict(color=(.6, .6, .6)), make_legend=False),
                    xlabel='Layer', ylabel='Scale')
        ax.set_xticks(np.arange(len(w_scales)))

    random_flop_counts, random_errors = compute_flop_errors_for_scales(
        train_data, scales_to_try, ws=ws, quantization_method='round', true_out=true_out, seed=1234)
    dbplot((random_flop_counts / 1e3 / len(train_data), random_errors), 'random_flop_errors',
           axis='Tradeoff', xlabel='kOps/sample', ylabel='Error',
           plot_type=lambda: LinePlot(plot_kwargs=dict(color=(.6, .6, .6), marker='.', linestyle=' ')))

    # Now run with optimization, across several values of K (total scale)
    for comp_weight in computation_weights:
        net = CompErrorScaleOptimizer(
            ws, optimizer=GradientDescent(learning_rate), comp_weight=comp_weight,
            layerwise_scales=layerwise_scales, hidden_activations='relu', output_activation='relu',
            parametrization=parametrization, rng=rng)
        f_train = net.train_scales.partial(error_loss=error_loss).compile()
        f_get_scales = net.get_scales.compile()

        for training_minibatch, iter_info in minibatch_iterate_info(
                train_data, minibatch_size=minibatch_size, n_epochs=n_epochs,
                test_epochs=np.arange(0, n_epochs, 1)):

            if iter_info.test_now:
                ks = f_get_scales()
                with hold_dbplots():
                    if show_random_scales:
                        dbplot(ks, 'solution_scales ' + str(comp_weight), axis='Scales',
                               plot_type=lambda: LinePlot(plot_kwargs=dict(linewidth=3), make_legend=False,
                                                          axes_update_mode='expand'))
                    current_flop_counts, current_outputs = quantized_forward_pass_cost_and_output(
                        train_data, weights=ws, scales=ks, quantization_method='round', seed=1234)
                    # Mean absolute error relative to the mean magnitude of the reference output.
                    current_error = np.abs(current_outputs - true_out).mean() / np.abs(true_out).mean()
                    if np.isnan(current_error):
                        # Single-argument call form: same output on Python 2, valid on Python 3.
                        print('ERROR IS NAN!!!')
                    dbplot((current_flop_counts / 1e3 / len(train_data), current_error),
                           'k=%.3g curve' % (comp_weight, ), axis='Tradeoff',
                           plot_type=lambda: Moving2DPointPlot(
                               legend_entries='$\\lambda=%.3g$' % comp_weight,
                               axes_update_mode='expand', legend_entry_size=11))

            f_train(training_minibatch)

    if hang_now:
        dbplot_hang()
def demo_optimize_conv_scales(n_epochs=5, comp_weight=1e-11, learning_rate=0.1, error_loss='KL', use_softmax=True,
                              optimizer='sgd', shuffle_training=False):
    """
    Run the scale optimization routine on a convnet.

    The scales of a ScaleLearningRoundingConvnet are trained on frames from two
    training videos.  At every test step the current scales are evaluated on
    two test videos with both a rounding and a sigma-delta forward pass, the
    computation and cumulative accuracy plots are refreshed, and the figure is
    saved as a PDF in the output directory.

    :param n_epochs: Number of epochs passed to minibatch_iterate_info; a test
        step is requested every 0.1 epoch.
    :param comp_weight: Passed to ``train_scales`` as ``comp_weight``.
    :param learning_rate: Learning rate handed to get_named_optimizer.
    :param error_loss: Passed to ``train_scales`` as ``error_loss``.  'KL' is
        only accepted together with use_softmax=True.
    :param use_softmax: If True the network is built up to the 'prob' layer,
        otherwise up to 'fc8'.
    :param optimizer: Optimizer name handed to get_named_optimizer.
    :param shuffle_training: Passed as ``shuffle`` when loading the training
        videos.
    :return: A dict holding the quantities computed at the most recent test
        step (scales, costs, outputs, guesses and top-1/top-5 correctness for
        the 'Round' and 'SigmaDelta' passes), or None if no test step ran.
    :raises ValueError: If error_loss is 'KL' while use_softmax is False.
    """
    if error_loss == 'KL' and not use_softmax:
        raise ValueError(
            "It's very strange that you want to use a KL divergence on something other than a softmax error. I assume you've made a mistake.")

    training_videos, training_vgg_inputs = get_vgg_video_splice(
        ['ILSVRC2015_train_00033010', 'ILSVRC2015_train_00336001'], shuffle=shuffle_training, shuffling_rng=1234)
    test_videos, test_vgg_inputs = get_vgg_video_splice(['ILSVRC2015_train_00033009', 'ILSVRC2015_train_00033007'])

    set_dbplot_figure_size(12, 6)

    n_frames_to_show = 10
    # Floor division keeps the frame indices integral, with or without true division in effect.
    frame_stride = len(test_videos) // n_frames_to_show
    display_frames = np.arange(frame_stride // 2, len(test_videos), frame_stride)
    ax1 = dbplot(np.concatenate(test_videos[display_frames], axis=1), "Test Videos", title='', plot_type='pic')
    plt.subplots_adjust(wspace=0, hspace=.05)
    # One tick at the centre of every second displayed frame (frames are tiled at a pitch of 224).
    ax1.set_xticks(224 * np.arange(len(display_frames) // 2) * 2 + 224 // 2)
    ax1.tick_params(labelbottom='on')

    layers = get_vgg_layer_specifiers(up_to_layer='prob' if use_softmax else 'fc8')
    input_shape = (3, 224, 224)

    # Setup the true VGGnet and get the outputs
    f_true = ConvNet.from_init(layers, input_shape=input_shape).compile()
    true_test_out = flatten2(np.concatenate([f_true(frame_positions[None]) for frame_positions in test_vgg_inputs]))
    top5_true_guesses = argtopk(true_test_out, 5)
    true_guesses = np.argmax(true_test_out, axis=1)
    true_labels = [get_vgg_label_at(g, short=True) for g in true_guesses[display_frames[::2]]]
    full_convnet_cost = np.array(
        [get_full_convnet_computational_cost(layer_specs=layers, input_shape=input_shape)] * len(test_videos))

    # Setup the approximate networks
    slrc_net = ScaleLearningRoundingConvnet.from_convnet_specs(
        layers, optimizer=get_named_optimizer(optimizer, learning_rate=learning_rate), corruption_type='rand',
        rng=1234)
    f_train_slrc = slrc_net.train_scales.partial(comp_weight=comp_weight, error_loss=error_loss).compile()
    f_get_scales = slrc_net.get_scales.compile()
    round_fp = RoundConvNetForwardPass(layers)
    sigmadelta_fp = SigmaDeltaConvNetForwardPass(layers, input_shape=input_shape)

    p = ProgressIndicator(n_epochs * len(training_videos))

    output_dir = make_dir(get_local_path('output/%T-convnet-spikes'))

    # Refreshed at every test step, so the function returns cleanly even if no test step runs.
    results = None

    for input_minibatch, minibatch_info in minibatch_iterate_info(
            training_vgg_inputs, n_epochs=n_epochs, minibatch_size=1, test_epochs=np.arange(0, n_epochs, 0.1)):

        if minibatch_info.test_now:
            with EZProfiler('test'):
                current_scales = f_get_scales()
                round_cost, round_out = round_fp.get_cost_and_output(test_vgg_inputs, scales=current_scales)
                sd_cost, sd_out = sigmadelta_fp.get_cost_and_output(test_vgg_inputs, scales=current_scales)
                round_guesses, round_top1_correct, round_top5_correct = get_and_report_scores(
                    round_cost, round_out, name='Round', true_top_1=true_guesses, true_top_k=top5_true_guesses)
                sd_guesses, sd_top1_correct, sd_top5_correct = get_and_report_scores(
                    sd_cost, sd_out, name='SigmaDelta', true_top_1=true_guesses, true_top_k=top5_true_guesses)

                # Label each ticked frame with the true network's guess above the rounding network's guess.
                round_labels = [get_vgg_label_at(g, short=True) for g in round_guesses[display_frames[::2]]]
                ax1.set_xticklabels(['{}\n{}'.format(tg, rg) for tg, rg in izip_equal(true_labels, round_labels)])

                ax = dbplot(
                    np.array([round_cost / 1e9, sd_cost / 1e9, full_convnet_cost / 1e9]).T,
                    'Computation',
                    plot_type='thick-line',
                    ylabel='GOps',
                    title='',
                    # Raw strings: '\S' and '\D' are not valid escape sequences.
                    legend=['Round', r'$\Sigma\Delta$', 'Original'],
                )
                ax.set_xticklabels([])
                plt.grid()
                dbplot(
                    100 * np.array([cummean(sd_top1_correct), cummean(sd_top5_correct)]).T,
                    "Score",
                    plot_type=lambda: LinePlot(
                        y_bounds=(0, 100),
                        plot_kwargs=[dict(linewidth=3, color='k'), dict(linewidth=3, color='k', linestyle=':')]),
                    title='',
                    legend=[r'Round/$\Sigma\Delta$ Top-1', r'Round/$\Sigma\Delta$ Top-5'],
                    ylabel='Cumulative\nPercent Accuracy',
                    xlabel='Frame #',
                    layout='v',
                )
                plt.grid()
            plt.savefig(os.path.join(output_dir, 'epoch-%.3g.pdf' % (minibatch_info.epoch, )))

            results = dict(
                current_scales=current_scales,
                round_cost=round_cost,
                round_out=round_out,
                sd_cost=sd_cost,
                sd_out=sd_out,
                round_guesses=round_guesses,
                round_top1_correct=round_top1_correct,
                round_top5_correct=round_top5_correct,
                sd_guesses=sd_guesses,
                sd_top1_correct=sd_top1_correct,
                sd_top5_correct=sd_top5_correct,
            )

        f_train_slrc(input_minibatch)
        p()
        # Single-argument call form: same output on Python 2, valid on Python 3.
        print("Epoch {:3.2f}: Scales: {}".format(minibatch_info.epoch, ['%.3g' % float(s) for s in f_get_scales()]))

    dbplot_hang()
    return results