def demo_temporal_mnist(n_samples=None, smoothing_steps=200):
    """
    Show MNIST digits next to their temporally-smoothed counterparts, one pair per frame.

    The third and fourth entries of each dataset's ``xyxy`` (presumably the test inputs and labels) are
    walked in lockstep; the plain digit goes to the 'sample' plot and the smoothed one to the 'smooth' plot,
    each titled with its label.

    :param n_samples: Forwarded to both loaders as ``n_training_samples`` and ``n_test_samples``.
    :param smoothing_steps: Forwarded to ``get_temporal_mnist_dataset``.
    """
    subset_sizes = dict(n_training_samples=n_samples, n_test_samples=n_samples)
    _, _, plain_images, plain_labels = get_mnist_dataset(**subset_sizes).xyxy
    _, _, smoothed_images, smoothed_labels = get_temporal_mnist_dataset(
        smoothing_steps=smoothing_steps, **subset_sizes).xyxy
    frames = zip(plain_images, plain_labels, smoothed_images, smoothed_labels)
    for plain_image, plain_label, smoothed_image, smoothed_label in frames:
        with hold_dbplots():
            dbplot(plain_image, 'sample', title=str(plain_label))
            dbplot(smoothed_image, 'smooth', title=str(smoothed_label))
def test_periodic_plotting():
    """
    Feed a sine and a cosine to dbplot 100 times, sleeping 20 ms between updates.

    The updates are wrapped in ``hold_dbplots(draw_every='1s')`` so redraws are throttled by time.
    """
    for step in range(100):
        phase = step / 10
        with hold_dbplots(draw_every='1s'):
            dbplot(np.sin(phase), 'sinusoid')
            dbplot(np.cos(phase), 'cosinusoid')
        time.sleep(0.02)
def demo_settling_dynamics(symmetric=False, n_hidden=50, n_out=3, input_influence=0.01, learning_rate=0.0001,
                           cut_time=None, minibatch_size=1, decay=0.05, scale=.4, hidden_act='tanh',
                           output_act='lin', draw_every=10, n_steps=10000, seed=124):
    """
    Drive a learning network with a free-running network's outputs and plot how the two evolve.

    Intended use (from the original description): use Predictive Coding and compare the convergence of a
    predictive-coded network to one without.

    Two networks are built from the same random generator:
      * ``net_d`` (driver) evolves on its own, with no input.
      * ``net_l`` (learner) is additionally configured with ``input_influence`` and ``learning_rate`` and is
        given the driver's output units ``state_d.x`` as input on every step before ``cut_time`` (on all steps
        when ``cut_time`` is None); from ``cut_time`` on it receives no input.
    Each step plots row 0 of the hidden (``h``) and output (``x``) units of both networks (presumably the first
    minibatch item), the mean absolute value of the learner's ``w_hx``, and the mean signed difference between
    row 0 of the two networks' outputs.

    :param symmetric, n_hidden, n_out, scale, decay: Forwarded to ``Network.from_init`` for both networks.
    :param hidden_act: Forwarded as ``fh`` to both networks.
    :param output_act: Forwarded as ``fx`` to both networks.
    :param input_influence: Forwarded to the learner's ``Network.from_init`` only.
    :param learning_rate: Forwarded to the learner's ``Network.from_init`` only.
    :param cut_time: Step from which the learner stops receiving the driver's outputs (None = never).
    :param minibatch_size: Passed to ``init_state`` of both networks.
    :param draw_every: Passed to ``hold_dbplots`` to throttle redraws.
    :param n_steps: Number of simulation steps.
    :param seed: Seed for ``get_rng``; the same generator initialises both networks.
    """
    rng = get_rng(seed)
    net_d = Network.from_init(symmetric=symmetric, n_hidden=n_hidden, n_out=n_out, scale=scale, fh=hidden_act,
                              fx=output_act, decay=decay, rng=rng)
    state_d = net_d.init_state(minibatch_size=minibatch_size)

    net_l = Network.from_init(symmetric=symmetric, n_hidden=n_hidden, n_out=n_out, scale=scale, fh=hidden_act,
                              fx=output_act, decay=decay, rng=rng, input_influence=input_influence,
                              learning_rate=learning_rate)
    state_l = net_l.init_state(minibatch_size=minibatch_size)

    sp = Speedometer()

    for t in range(n_steps):
        # Signed mean difference between the driver's and the learner's outputs (row 0).
        error = (state_d.x[0] - state_l.x[0]).mean()
        with hold_dbplots(draw_every=draw_every):
            dbplot(state_d.h[0], 'hd')
            dbplot(state_d.x[0], 'xd')
            dbplot(state_l.h[0], 'hl')
            dbplot(state_l.x[0], 'xl')
            dbplot(np.array([abs(net_l.w_hx).mean()]), 'wmag')
            dbplot(error, 'error')

        state_d = net_d.update(state_d)
        # The learner sees the (already updated) driver output until cut_time, then runs without input.
        state_l = net_l.update(state_l, inp=state_d.x if cut_time is None or t < cut_time else None)
        if t % 100 == 0:
            print(f'Rate: {sp(t+1)} iter/s')
def run_plotting_server(address, port):
    """
    Run a blocking plotting server that renders dbplot calls sent by clients over a socket.

    A daemon thread (``handle_socket_accepts``) accepts client connections and feeds a shared input queue.
    This (main) thread drains that queue in batches of at most ``max_plot_batch_size``. It plots every
    message of a batch inside one ``hold_dbplots`` block and then puts ``[client_address, plot_id]`` on the
    return queue for each rendered message, so clients know their plot was drawn. The loop ends when
    ``GracefulKiller`` reports a kill signal. ``save_current_figure`` is registered with ``atexit``.

    :param address: Address to listen on.
    :param port: Port to start searching from. ``get_socket`` returns the port actually bound, which is
        written to a file and printed so the client that started this server can find it.
    :return: None
    """
    # Get the first available socket starting from `port` and communicate it to the client who started this server
    sock, port = get_socket(address=address, port=port)
    write_port_to_file(port)
    max_number_clients = 100
    max_plot_batch_size = 20000
    sock.listen(max_number_clients)
    print(port)
    print("Plotting Server is listening")

    # We want to save and rescue the current plot in case the plotting server receives a signal.SIGINT (2)
    killer = GracefulKiller()

    # The plotting server receives input in a queue and returns the plot_ids as a way of communicating that it
    # has rendered the plot
    main_input_queue = Queue.Queue()
    return_queue = Queue.Queue()

    # Start accepting clients' communication requests
    t0 = threading.Thread(target=handle_socket_accepts,
                          args=(sock, main_input_queue, return_queue, max_number_clients))
    # `setDaemon()` is deprecated since Python 3.10; the `daemon` attribute works on 2.6+ and all 3.x.
    t0.daemon = True
    t0.start()

    # If killed, save the current figure
    atexit.register(save_current_figure)

    # Now, we can accept plots in the main thread!
    while True:
        if killer.kill_now:
            # The server has received a signal.SIGINT (2), so we stop receiving plots and terminate.
            # Release the listening socket instead of leaking it.
            sock.close()
            break
        # Retrieve data points that might have come in in the mean-time:
        client_messages = _queue_get_all_no_wait(main_input_queue, max_plot_batch_size)
        # client_messages is a list of ClientMessage objects
        if len(client_messages) > 0:
            return_values = []
            with hold_dbplots():
                for client_msg in client_messages:  # For each ClientMessage object
                    # Take apart the received message, plot, and return the plot_id to the client who sent it.
                    # SECURITY: pickle.loads on bytes received over a socket can execute arbitrary code; only
                    # expose this server to trusted clients / addresses.
                    plot_message = pickle.loads(client_msg.dbplot_message)  # A DBPlotMessage object (see plotting_client.py)
                    plot_message.dbplot_args['draw_now'] = False  # hold_dbplots draws the whole batch at once
                    dbplot(**plot_message.dbplot_args)
                    return_values.append((client_msg.client_address, plot_message.plot_id))
            for client, plot_id in return_values:
                return_queue.put([client, plot_id])
        else:
            time.sleep(0.1)
def test_two_plots_in_the_same_axis_version_2():
    """
    Overlay a histogram and its fitted Gaussian density by giving both plots the same 'axis' name.
    """
    reset_dbplot()
    # Option 2: Give both plots the same 'axis' argument
    for _ in xrange(5):
        samples = np.random.randn(200)
        grid = np.linspace(-5, 5, 100)
        mu = np.mean(samples)
        var = np.var(samples)
        gaussian_pdf = 1./np.sqrt(2*np.pi*var) * np.exp(-(grid-mu)**2/(2*var))
        with hold_dbplots():
            dbplot(samples, 'histogram', plot_type='histogram', axis='hist')
            dbplot((grid, gaussian_pdf), 'density', axis='hist', plot_type='line')
def test_two_plots_in_the_same_axis_version_1():
    """
    Overlay a histogram and its fitted Gaussian density by pointing the second plot's 'axis' at the first plot.
    """
    reset_dbplot()
    # Option 1: Name the 'axis' argument to the second plot after the name of the first
    for _ in xrange(5):
        samples = np.random.randn(200)
        grid = np.linspace(-5, 5, 100)
        mu, var = np.mean(samples), np.var(samples)
        density = 1./np.sqrt(2*np.pi*var) * np.exp(-(grid-mu)**2/(2*var))
        with hold_dbplots():
            dbplot(samples, 'histogram', plot_type='histogram')
            dbplot((grid, density), 'density', axis='histogram', plot_type='line')
def test_inline_custom_plots():
    """
    Mix a regular dbplot line plot with hand-made pyplot lines drawn onto a dbplot-managed axis.
    """
    for step in range(10):
        with hold_dbplots():
            wave = np.sin(step / 10. + np.linspace(0, 10, 200))
            dbplot(wave, 'x', plot_type='line')
            # Make the (cleared) 'custom' axis current so the plain pyplot calls below draw onto it.
            use_dbplot_axis('custom', clear=True)
            for curve, curve_label in ((wave, 'x'), (wave**2, '$x**2$')):
                plt.plot(curve, label=curve_label, linewidth=2)
def test_bbox_display():
    """
    Draw a bounding box onto an image axis within the same hold_dbplots block as the image.

    Regression check: bboxes once failed when drawn in a hold block together with their image.
    """
    def random_uint8_image():
        # 40x40 image of uniformly random bytes.
        return (np.random.rand(40, 40) * 255.999).astype(np.uint8)

    with hold_dbplots():
        for image_name in ('gfdsg', 'img'):
            dbplot(random_uint8_image(), image_name)
        dbplot([10, 20, 25, 30], 'bbox', axis='img', plot_type=DBPlotTypes.BBOX)
def demo_dbplot(n_frames=1000):
    """
    Demonstrates the various types of plots.

    The appropriate plot can usually be inferred from the first input data.  In cases where there are multiple
    ways to display the input data, you can use the plot_type argument.

    :param n_frames: Number of frames to draw. Every 10th frame, the rate of that single frame is printed.
    """
    # Approximate frame rates:
    # Macbook Air, MacOSX backend, mode=safe ~2.4 FPS
    # Macbook Air, MacOSX backend, mode=fast: FAILS
    # Macbook Air, TkAgg backend, mode=safe: ~1.48 FPS
    # Macbook Air, TkAgg backend, mode=fast: ~4.4 FPS
    # Linux box, Qt4Agg backend, mode=safe: ~2.5 FPS
    # Linux box, Qt4Agg backend, mode=fast: (plot does not update)
    # set_dbplot_figure_size(150, 100)
    for i in range(n_frames):
        t_start = time.time()
        with hold_dbplots():  # Sets it so that all plots update at once (rather than redrawing on each call, which is slower)
            dbplot(np.random.randn(20, 20), 'Greyscale Image')
            dbplot(np.random.randn(20, 20, 3), 'Colour Image')
            dbplot(np.random.randn(15, 20, 20), "Many Images")
            dbplot(np.random.randn(3, 6, 20, 20, 3), "Colour Image Grid")
            dbplot([np.random.randn(15, 20, 3), np.random.randn(10, 10, 3), np.random.randn(10, 30, 3)],
                   "Differently Sized images")
            dbplot(np.random.randn(20, 2), 'Two Lines')
            dbplot((np.linspace(-5, 5, 100) + np.sin(np.linspace(-5, 5, 100) * 2), np.linspace(-5, 5, 100)),
                   '(X,Y) Lines')
            dbplot([np.sin(i / 20.), np.sin(i / 15.)], 'Moving Point History')
            dbplot([np.sin(i / 20.), np.sin(i / 15.)], 'Bounded memory history',
                   plot_type=partial(MovingPointPlot, buffer_len=10))
            dbplot((np.cos(i / 10.), np.sin(i / 11.)), '(X,Y) Moving Point History')
            dbplot((np.cos(i / np.array([10., 15., 20.])), np.sin(i / np.array([11., 16., 21.]))) + np.array([0, 1, 2]),
                   'multi (X,Y) Moving Point History', plot_type='trajectory')
            dbplot(np.random.randn(20), 'Vector History')
            dbplot(np.random.randn(50), 'Histogram', plot_type='histogram')
            dbplot(np.random.randn(50), 'Cumulative Histogram', plot_type='cumhist')
            dbplot(('Veni', 'Vidi', 'Vici')[i % 3], 'text-history')
            dbplot(('Veni', 'Vidi', 'Vici')[i % 3], 'text-notice', plot_type='notice')
        if i % 10 == 0:
            # '.3g' = 3 significant digits ('{:3g}' only set a minimum field width of 3).
            print('Frame Rate: {:.3g}FPS'.format(1. / (time.time() - t_start)))
def demo_dbplot():
    """
    Demonstrates the various types of plots.

    The appropriate plot can usually be inferred from the first input data.  In cases where there are multiple
    ways to display the input data, you can use the plot_type argument.

    Draws 1000 frames inside an ``EZProfiler`` and prints the mean frame rate so far every 10th frame.
    """
    from matplotlib import pyplot as plt
    plt.ion()
    set_dbplot_figure_size(15, 10)
    with EZProfiler('plot time') as prof:
        # `range` instead of `xrange`: identical iteration, and also valid on Python 3.
        for i in range(1000):
            with hold_dbplots():  # Sets it so that all plots update at once (rather than redrawing on each call, which is slower)
                dbplot(np.random.randn(20, 20), 'Greyscale Image')
                dbplot(np.random.randn(20, 20, 3), 'Colour Image')
                dbplot(np.random.randn(15, 20, 20), "Many Images")
                dbplot(np.random.randn(3, 6, 20, 20, 3), "Colour Image Grid")
                dbplot([np.random.randn(15, 20, 3), np.random.randn(10, 10, 3), np.random.randn(10, 30, 3)],
                       "Differently Sized images")
                dbplot(np.random.randn(20, 2), 'Two Lines')
                dbplot((np.linspace(-5, 5, 100) + np.sin(np.linspace(-5, 5, 100) * 2), np.linspace(-5, 5, 100)),
                       '(X,Y) Lines')
                dbplot([np.sin(i / 20.), np.sin(i / 15.)], 'Moving Point History')
                dbplot([np.sin(i / 20.), np.sin(i / 15.)], 'Bounded memory history',
                       plot_type=lambda: MovingPointPlot(buffer_len=10))
                dbplot((np.sin(i / 5.) * (i + 100.), np.cos(i / 5.) * (i + 100.)), '(X,Y) Moving Point History')
                dbplot((np.sin(np.array([i, i * 1.5, i * 2]) / 5.) * (i + 10.),
                        np.cos(np.array([i, i * 1.5, i * 2]) / 5.) * (i + 10.)),
                       'multi (X,Y) Moving Point History', plot_type='trajectory')
                dbplot(np.random.randn(20), 'Vector History')
                dbplot(np.random.randn(50), 'Histogram', plot_type='histogram')
                dbplot(np.random.randn(50), 'Cumulative Histogram', plot_type='cumhist')
                dbplot(('Veni', 'Vidi', 'Vici')[i % 3], 'text-history')
                dbplot(('Veni', 'Vidi', 'Vici')[i % 3], 'text-notice', plot_type='notice')
            if i % 10 == 0:
                # Function-call form: a SyntaxError-free equivalent of the Python 2 print statement.
                print('Mean Frame Rate: %.3gFPS' % ((i + 1) / prof.get_current_time(), ))
def demo_temporal_mnist(n_samples=None, smoothing_steps=200):
    """
    Step through MNIST digits alongside their temporally-smoothed versions, one pair per frame.

    Uses the third and fourth entries of each dataset's ``xyxy`` (presumably the test inputs and labels).

    :param n_samples: Forwarded to both loaders as ``n_training_samples`` and ``n_test_samples``.
    :param smoothing_steps: Forwarded to ``get_temporal_mnist_dataset``.
    """
    _, _, digits, digit_labels = get_mnist_dataset(
        n_training_samples=n_samples, n_test_samples=n_samples).xyxy
    _, _, smooth_digits, smooth_labels = get_temporal_mnist_dataset(
        n_training_samples=n_samples, n_test_samples=n_samples, smoothing_steps=smoothing_steps).xyxy
    for frame in zip(digits, digit_labels, smooth_digits, smooth_labels):
        digit, digit_label, smooth_digit, smooth_label = frame
        with hold_dbplots():
            dbplot(digit, 'sample', title=str(digit_label))
            dbplot(smooth_digit, 'smooth', title=str(smooth_label))
def transaction_chooser(futures_stream, transaction_cost, initial_have_state=False):
    """
    Given a stream which yields future forecast samples, generate (print and mark) buy/sell transactions.

    For each step, the sampled futures are averaged over samples into an expected trajectory. The price
    history followed by every sampled future is plotted on the 'futures' axis. Then:

    * when not holding, BUY if some step's expected value exceeds ``x0 + transaction_cost`` and, before that
      step, the expected trajectory never drops below ``x0``;
    * when holding, SELL if some step's expected value falls below ``x0 - transaction_cost`` and, before that
      step, the expected trajectory never rises above ``x0``.

    :param futures_stream: Iterable of ``(x0, futures)`` pairs. ``x0`` is the current value and ``futures``
        is an ``(n_samples, n_steps)`` array whose column 0 holds n_samples copies of the current time-step.
    :param transaction_cost: Margin the expected future must clear before a trade is made.
    :param initial_have_state: True if we start out holding, False otherwise.
    :return: None. Trades are printed and drawn with ``mark_trade`` on the 'futures' axis.
    """
    def first_step_where(trajectory, predicate):
        # Index of the first step satisfying `predicate`, or None if no step does.
        return next((step for step, value in enumerate(trajectory) if predicate(value)), None)

    holding = initial_have_state
    seen_prices = []
    for t, (x0, futures) in enumerate(futures_stream):
        mean_future = np.mean(futures, axis=0)  # (samples, time) -> (time, )
        seen_prices.append(x0)
        with hold_dbplots(draw_every='0.05s'):
            # Prefix every sampled future with the whole price history so each line reads past -> future.
            stacked = np.concatenate([[np.array(seen_prices)] * len(futures), futures], axis=1)
            dbplot(stacked.T, 'futures', plot_type=('line', dict(color='C0', axes_update_mode='expand')))

        if holding:
            t_buy = first_step_where(mean_future, lambda m: m < x0 - transaction_cost)
            if t_buy is None:
                continue
            if t_buy == 0 or not mean_future[:t_buy].max() > x0:
                print(f'Selling at t={t}, for ${x0:.3g}')
                sell = Trade(time=t, price=x0, type=TradeTypes.SELL)
                mark_trade(trade=sell, ax=use_dbplot_axis('futures'))
                holding = False
        else:
            t_sell = first_step_where(mean_future, lambda m: m > x0 + transaction_cost)
            if t_sell is None:
                continue
            if t_sell == 0 or mean_future[:t_sell].min() >= x0:
                print(f'Buying at t={t}, for ${x0:.3g}')
                buy = Trade(time=t, price=x0, type=TradeTypes.BUY)
                mark_trade(trade=buy, ax=use_dbplot_axis('futures'))
                holding = True
def test_moving_point_multiple_points():
    """
    Feed two sinusoids (periods 5 and 8) into moving-point and resampling line plots with different buffers.
    """
    reset_dbplot()
    periods = (5., 8.)

    def two_sines(step):
        return np.sin([step / period for period in periods])

    for step in xrange(50):
        with hold_dbplots(draw_every=5):
            dbplot(two_sines(step), 'unlim buffer', plot_type=partial(MovingPointPlot))
            dbplot(two_sines(step), 'lim buffer', plot_type=partial(MovingPointPlot, buffer_len=20))
            # Only looks bad because of really small buffer length from testing.
            dbplot(two_sines(step), 'resampling buffer', plot_type=partial(ResamplingLineHistory, buffer_len=20))
def transaction_chooser(futures_stream, transaction_cost, riskiness=0.1, initial_have_state=False):
    """
    Decide when to buy and sell from a stream of sampled futures, printing and marking each trade.

    Each item of the stream is an array of sampled futures. It is presumably ``(n_samples, n_steps)``, given
    that it is averaged over axis 0 and indexed as ``futures[:, tau]``. The expected trajectory is the mean
    over samples, and the current value ``x0`` is its first entry. ``x0`` and the sampled futures are
    plotted, then:

    * when not holding, find the first step whose expected value exceeds ``x0 + transaction_cost`` and for
      which ``riskiness > -expected_shortfall(x0, futures[:, step])``. BUY if that step is 1, or if every
      expected value strictly between now and that step is above ``x0``;
    * when holding, find the first step whose expected value is below ``x0 - transaction_cost``. SELL if that
      step is 1, or if no expected value strictly between now and that step exceeds ``x0``.

    :param futures_stream: Iterable of arrays of sampled futures.
    :param transaction_cost: Margin the expected future must clear before trading.
    :param riskiness: Bound that the negated expected shortfall of a candidate sell step must stay below.
    :param initial_have_state: True if we start out holding.
    :return: None.
    """
    holding = initial_have_state
    for t, futures in enumerate(futures_stream):
        mean_future = np.mean(futures, axis=0)  # (samples, time) -> (time, )
        x0 = mean_future[0]
        with hold_dbplots():
            dbplot(x0, 'x0')
            dbplot(futures.T, 'futures', plot_type='line')

        if holding:
            # Is there a moment when buying a new share is cheaper than keeping this one...
            t_buy = get_next_or_none(
                tau for tau, m in enumerate(mean_future) if m < x0 - transaction_cost)
            # ...with no better selling moment before it?
            if t_buy is not None and (t_buy == 1 or not mean_future[1:t_buy].max() > x0):
                print('SELL SELL SELL')
                use_dbplot_axis('x0')
                mark_trade(trade_time=t, trade_price=x0, trade_type='sell')
                holding = False
        else:
            # Is there a profitable (and not too risky) moment to sell...
            t_sell = get_next_or_none(
                tau for tau, m in enumerate(mean_future)
                if m > x0 + transaction_cost and riskiness > -1 * expected_shortfall(x0, futures[:, tau]))
            # ...with no better buying moment before it?
            if t_sell is not None and (t_sell == 1 or mean_future[1:t_sell].min() > x0):
                print('BUY BUY BUY')
                use_dbplot_axis('x0')
                mark_trade(trade_time=t, trade_price=x0, trade_type='buy')
                holding = True
def demo_settling_dynamics(symmetric=False, n_hidden=50, n_out=3, minibatch_size=1, decay=0.05, scale=.4,
                           hidden_act='tanh', output_act='lin', draw_every=10, n_steps=10000, seed=124):
    """
    Run a single freshly initialised network for ``n_steps`` updates and plot its units as they evolve.

    (The previous docstring, copied from a two-network demo, described comparing a predictive-coded network
    to one without; this function builds and runs only one network.)

    Each step updates the state, prints the iteration rate every 100 steps, and plots row 0 of the hidden
    units ``state.h`` and output units ``state.x``, titled with ``net.b_h``.

    :param symmetric, n_hidden, n_out, scale, decay: Forwarded to ``Network.from_init``.
    :param hidden_act: Forwarded as ``fh``.
    :param output_act: Forwarded as ``fx``.
    :param minibatch_size: Passed to ``init_state``.
    :param draw_every: Passed to ``hold_dbplots`` to throttle redraws.
    :param n_steps: Number of updates to run.
    :param seed: Passed directly as ``rng`` to ``Network.from_init``.
    """
    net = Network.from_init(symmetric=symmetric, n_hidden=n_hidden, n_out=n_out, scale=scale, fh=hidden_act,
                            fx=output_act, decay=decay, rng=seed)
    state = net.init_state(minibatch_size=minibatch_size)

    sp = Speedometer()
    for t in range(n_steps):
        state = net.update(state)
        if t % 100 == 0:
            print(f'Rate: {sp(t+1)} iter/s')
        with hold_dbplots(draw_every=draw_every):
            dbplot(state.h[0], 'Hidden Units', title='Hidden Units (b={})'.format(net.b_h))
            # NOTE(review): this title also reports net.b_h; confirm whether an output-layer bias was intended.
            dbplot(state.x[0], 'Y Units', title='Y Units (b={})'.format(net.b_h))
def demo_dbplot(n_frames=1000):
    """
    Demonstrates the various types of plots.

    The appropriate plot can usually be inferred from the first input data.  In cases where there are multiple
    ways to display the input data, you can use the plot_type argument.

    :param n_frames: Number of frames to draw. Every 10th frame, the rate of that single frame is printed.
    """
    # Approximate frame rates:
    # Macbook Air, MacOSX backend, mode=safe ~2.4 FPS
    # Macbook Air, MacOSX backend, mode=fast: FAILS
    # Macbook Air, TkAgg backend, mode=safe: ~1.48 FPS
    # Macbook Air, TkAgg backend, mode=fast: ~4.4 FPS
    # Linux box, Qt4Agg backend, mode=safe: ~2.5 FPS
    # Linux box, Qt4Agg backend, mode=fast: (plot does not update)
    # set_dbplot_figure_size(150, 100)
    for i in range(n_frames):
        t_start = time.time()
        with hold_dbplots():  # Sets it so that all plots update at once (rather than redrawing on each call, which is slower)
            dbplot(np.random.randn(20, 20), 'Greyscale Image')
            dbplot(np.random.randn(20, 20, 3), 'Colour Image')
            dbplot(np.random.randn(15, 20, 20), "Many Images")
            dbplot(np.random.randn(3, 6, 20, 20, 3), "Colour Image Grid")
            dbplot([np.random.randn(15, 20, 3), np.random.randn(10, 10, 3), np.random.randn(10, 30, 3)],
                   "Differently Sized images")
            dbplot(np.random.randn(20, 2), 'Two Lines')
            dbplot((np.linspace(-5, 5, 100)+np.sin(np.linspace(-5, 5, 100)*2), np.linspace(-5, 5, 100)),
                   '(X,Y) Lines')
            dbplot([np.sin(i/20.), np.sin(i/15.)], 'Moving Point History')
            dbplot([np.sin(i/20.), np.sin(i/15.)], 'Bounded memory history',
                   plot_type=partial(MovingPointPlot, buffer_len=10))
            dbplot((np.cos(i/10.), np.sin(i/11.)), '(X,Y) Moving Point History')
            dbplot((np.cos(i/np.array([10., 15., 20.])), np.sin(i/np.array([11., 16., 21.])))+np.array([0, 1, 2]),
                   'multi (X,Y) Moving Point History', plot_type='trajectory')
            dbplot(np.random.randn(20), 'Vector History')
            dbplot(np.random.randn(50), 'Histogram', plot_type='histogram')
            dbplot(np.random.randn(50), 'Cumulative Histogram', plot_type='cumhist')
            dbplot(('Veni', 'Vidi', 'Vici')[i % 3], 'text-history')
            dbplot(('Veni', 'Vidi', 'Vici')[i % 3], 'text-notice', plot_type='notice')
        if i % 10 == 0:
            # '.3g' = 3 significant digits ('{:3g}' only set a minimum field width of 3).
            print('Frame Rate: {:.3g}FPS'.format(1./(time.time() - t_start)))
def demo_converge_to_pareto_curve(layer_sizes=[100, 100, 100, 100], w_scales=[1, 1, 1], n_samples=100,
                                  learning_rate=0.01, n_epochs=100, minibatch_size=10, n_random_points_to_try=1000,
                                  random_scale_range=(1, 5), parametrization='log',
                                  computation_weights=np.logspace(-6, -3, 8), layerwise_scales=True,
                                  show_random_scales=True, error_loss='L1', hang_now=True, seed=1234):
    """
    Compare randomly chosen per-layer quantization scales with scales optimised for several compute weights.

    1. Build a random ReLU network (weights multiplied per layer by ``w_scales``) and random input data, and
       compute its unquantized output as the reference.
    2. Random search: draw ``n_random_points_to_try`` scale vectors (absolute values of a normal centred on the
       middle of ``random_scale_range``), and plot their (kOps/sample, relative error) points on 'Tradeoff'.
       Optionally, also plot the scale vectors themselves on 'Scales'.
    3. For each value in ``computation_weights``, train the scales with ``CompErrorScaleOptimizer``. Once per
       epoch, plot the current scales (if ``show_random_scales``) and the current (kOps/sample, relative
       error) point on 'Tradeoff'.

    The relative error is ``mean|out - true_out| / mean|true_out|``.

    :param layer_sizes: Sizes of the network layers (input first). Only read, never mutated.
    :param w_scales: Per-layer multipliers applied to the initial weights. Only read, never mutated.
    :param n_samples: Number of random input rows.
    :param learning_rate: Learning rate for ``GradientDescent``.
    :param n_epochs: Training epochs per compute weight.
    :param minibatch_size: Minibatch size for training the scales.
    :param n_random_points_to_try: Number of random scale vectors in the random search.
    :param random_scale_range: (low, high); its mean and difference set the random search distribution.
    :param parametrization: Forwarded to ``CompErrorScaleOptimizer``.
    :param computation_weights: Compute-cost weights to optimise for, one training run each.
    :param layerwise_scales: Forwarded to ``CompErrorScaleOptimizer``.
    :param show_random_scales: Whether to plot scale vectors on the 'Scales' axis.
    :param error_loss: Forwarded to ``train_scales``.
    :param hang_now: If True, call ``dbplot_hang()`` at the end.
    :param seed: Seed for the RandomState used for weights, data, the random search and the optimisers.
    """
    set_dbplot_default_layout('h')

    rng = np.random.RandomState(seed)
    ws = initialize_network_params(layer_sizes=layer_sizes, mag='xavier-relu', include_biases=False, rng=rng)
    ws = [w * s for w, s in izip_equal(ws, w_scales)]
    train_data = rng.randn(n_samples, layer_sizes[0])
    # NOTE(review): the quantization seed is hard-coded to 1234 here and below, so it does not follow `seed`.
    _, true_out = quantized_forward_pass_cost_and_output(train_data, weights=ws, scales=None,
                                                         quantization_method=None, seed=1234)

    # Run the random search
    scales_to_try = np.abs(rng.normal(loc=np.mean(random_scale_range), scale=np.diff(random_scale_range),
                                      size=(n_random_points_to_try, len(ws))))

    if show_random_scales:
        ax = dbplot(scales_to_try.T, 'random_scales', axis='Scales',
                    plot_type=lambda: LinePlot(plot_kwargs=dict(color=(.6, .6, .6)), make_legend=False),
                    xlabel='Layer', ylabel='Scale')
        ax.set_xticks(np.arange(len(w_scales)))

    random_flop_counts, random_errors = compute_flop_errors_for_scales(
        train_data, scales_to_try, ws=ws, quantization_method='round', true_out=true_out, seed=1234)
    dbplot((random_flop_counts / 1e3 / len(train_data), random_errors), 'random_flop_errors', axis='Tradeoff',
           xlabel='kOps/sample', ylabel='Error',
           plot_type=lambda: LinePlot(plot_kwargs=dict(color=(.6, .6, .6), marker='.', linestyle=' ')))

    # Now run with optimization, across several values of K (total scale)
    for comp_weight in computation_weights:
        net = CompErrorScaleOptimizer(ws, optimizer=GradientDescent(learning_rate), comp_weight=comp_weight,
                                      layerwise_scales=layerwise_scales, hidden_activations='relu',
                                      output_activation='relu', parametrization=parametrization, rng=rng)
        f_train = net.train_scales.partial(error_loss=error_loss).compile()
        f_get_scales = net.get_scales.compile()
        for training_minibatch, iter_info in minibatch_iterate_info(train_data, minibatch_size=minibatch_size,
                                                                    n_epochs=n_epochs,
                                                                    test_epochs=np.arange(0, n_epochs, 1)):
            if iter_info.test_now:
                ks = f_get_scales()
                with hold_dbplots():
                    if show_random_scales:
                        dbplot(ks, 'solution_scales ' + str(comp_weight), axis='Scales',
                               plot_type=lambda: LinePlot(plot_kwargs=dict(linewidth=3), make_legend=False,
                                                          axes_update_mode='expand'))
                    current_flop_counts, current_outputs = quantized_forward_pass_cost_and_output(
                        train_data, weights=ws, scales=ks, quantization_method='round', seed=1234)
                    current_error = np.abs(current_outputs - true_out).mean() / np.abs(true_out).mean()
                    if np.isnan(current_error):
                        print('ERROR IS NAN!!!')
                    # Bind comp_weight as a default so the plot factory cannot pick up a later loop value
                    # if dbplot ever defers calling it.
                    dbplot((current_flop_counts / 1e3 / len(train_data), current_error),
                           'k=%.3g curve' % (comp_weight, ), axis='Tradeoff',
                           plot_type=lambda cw=comp_weight: Moving2DPointPlot(
                               legend_entries='$\\lambda=%.3g$' % cw, axes_update_mode='expand',
                               legend_entry_size=11))
            f_train(training_minibatch)

    if hang_now:
        dbplot_hang()
def run_plotting_server(address, port, client_address, client_port):
    """
    Run a blocking plotting server that renders dbplot calls sent by clients over a socket.

    A daemon thread (``handle_socket_accepts``) accepts client connections and feeds a shared input queue.
    This (main) thread drains that queue in batches of at most ``max_plot_batch_size`` and plots every message
    of a batch inside one ``hold_dbplots`` block. It then puts ``[client_address, plot_id]`` on the return
    queue for each rendered message.

    The first message named ``"exp_dir"`` (if it arrives before any other batch has been completed) is treated
    as a control message. Its ``data`` is registered as the ``save_current_figure`` argument at exit, and it is
    neither plotted nor acknowledged. If a batch completes without it, ``save_current_figure`` is registered
    without arguments instead. The loop ends when ``GracefulKiller`` reports a kill signal.

    :param address: Address to listen on.
    :param port: Port to start searching from. The port actually bound is written to a file and sent to the client.
    :param client_address: Address of the client that started this server; the bound port is sent there once.
    :param client_port: Port on which that client waits for the bound port.
    :return: None
    """
    # Get the first available socket starting from `port` and communicate it to the client who started this server
    sock, port = get_socket(address=address, port=port)
    write_port_to_file(port)
    max_number_clients = 100
    max_plot_batch_size = 2000
    sock.listen(max_number_clients)
    one_time_send_to(address=client_address, port=client_port, message=str(port))

    # We want to save and rescue the current plot in case the plotting server receives a signal.SIGINT (2)
    killer = GracefulKiller()

    # The plotting server receives input in a queue and returns the plot_ids as a way of communicating that it
    # has rendered the plot
    main_input_queue = Queue.Queue()
    return_queue = Queue.Queue()

    # Start accepting clients' communication requests
    t0 = threading.Thread(target=handle_socket_accepts,
                          args=(sock, main_input_queue, return_queue, max_number_clients))
    # `setDaemon()` is deprecated since Python 3.10; the `daemon` attribute works on 2.6+ and all 3.x.
    t0.daemon = True
    t0.start()

    # Received exp_dir on first db_plot_message?
    exp_dir_received = False

    # Now, we can accept plots in the main thread!
    set_dbplot_figure_size(9, 10)
    while True:
        if killer.kill_now:
            sock.close()  # will cause handle_socket_accepts thread to terminate
            # The server has received a signal.SIGINT (2), so we stop receiving plots and terminate
            break
        # Retrieve data points that might have come in in the mean-time:
        client_messages = _queue_get_all_no_wait(main_input_queue, max_plot_batch_size)
        # client_messages is a list of ClientMessage objects
        if len(client_messages) > 0:
            return_values = []
            with hold_dbplots():
                for client_msg in client_messages:  # For each ClientMessage object
                    # Take apart the received message, plot, and return the plot_id to the client who sent it.
                    # SECURITY: pickle.loads on bytes received over a socket can execute arbitrary code; only
                    # expose this server to trusted clients / addresses.
                    plot_message = pickle.loads(client_msg.dbplot_message)  # A DBPlotMessage object (see plotting_client.py)
                    plot_message.dbplot_args['draw_now'] = False  # hold_dbplots draws the whole batch at once
                    if not exp_dir_received and plot_message.dbplot_args["name"] == "exp_dir":
                        # Control message: remember where to save the figure at exit; do not plot it.
                        # (The original `if len(client_messages) == 1: continue / else: continue` had
                        # identical branches.)
                        atexit.register(save_current_figure, plot_message.dbplot_args["data"])
                        exp_dir_received = True
                        continue
                    axis = dbplot(**plot_message.dbplot_args)
                    axis.ticklabel_format(style='sci', useOffset=False)
                    return_values.append((client_msg.client_address, plot_message.plot_id))
                if not exp_dir_received:
                    # No exp_dir was sent with the first batch: fall back to the default save location.
                    atexit.register(save_current_figure)
                    exp_dir_received = True
                plt.rcParams.update({'axes.titlesize': 'small', 'axes.labelsize': 'small'})
                plt.subplots_adjust(hspace=0.4, wspace=0.6)
            for client, plot_id in return_values:
                return_queue.put([client, plot_id])
        else:
            time.sleep(0.1)
if __name__ == '__main__':
    # Visual check: stream batches of randomly positioned 64x64 crops from a sample image, showing the crops,
    # their bounding boxes on the source image, and a scatter of the returned position tensors.
    from src.peters_stuff.sample_data import SampleImages
    from src.peters_stuff.image_crop_generator import iter_bbox_batches
    from artemis.plotting.db_plotting import dbplot, hold_dbplots, DBPlotTypes

    img = SampleImages.sistine_512()
    normscale = 0.25
    dbplot(img, 'image')
    bbox_batches = iter_bbox_batches(image_shape=img.shape[:2], crop_size=(64, 64), batch_size=64,
                                     position_generator_constructor='normal', n_iter=None, normscale=normscale)
    for bboxes in bbox_batches:
        raw_image_crops, normed_image_crops, positions = get_normed_crops_and_position_tensors(
            img=img, bboxes=bboxes, scale=1. / normscale)
        with hold_dbplots():
            dbplot(raw_image_crops, 'crops')
            for box_index, bbox in enumerate(bboxes):
                dbplot(bbox, f'bbox[{box_index}]', axis='image', plot_type=DBPlotTypes.BBOX_R)
            # NOTE(review): no plot name is given here; confirm dbplot handles name=None as intended.
            dbplot((positions[:, 0].numpy(), positions[:, 1].numpy()), plot_type=DBPlotTypes.SCATTER)
subpath = \ 'ILSVRC2015/Data/VID/snippets/test' if 'test' in identifier else \ 'ILSVRC2015/Data/VID/snippets/val' if 'val' in identifier else \ 'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0001/' if os.path.exists(os.path.join(archive_folder_path, 'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0001/', identifier + '.mp4')) else \ 'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0002/' if os.path.exists(os.path.join(archive_folder_path, 'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0002/', identifier + '.mp4')) else \ 'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0003/' if os.path.exists(os.path.join(archive_folder_path, 'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0003/', identifier + '.mp4')) else \ bad_value(identifier, 'Could not find identifier: {}'.format(identifier, )) print('Loading %s' % (identifier, )) full_path = get_file_in_archive( relative_path='data/ILSVRC2015', subpath=os.path.join(subpath, identifier+'.mp4'), url='http://vision.cs.unc.edu/ilsvrc2015/ILSVRC2015_VID_snippets_final.tar.gz' ) video = smart_load_video(full_path, size=size, cut_edges=cut_edges, resize_mode=resize_mode, cut_edges_thresh=cut_edges_thresh) print('Done.') return video if __name__ == '__main__': import itertools from artemis.plotting.db_plotting import dbplot, hold_dbplots identifiers = ['ILSVRC2015_train_00033009', 'ILSVRC2015_train_00033010', 'ILSVRC2015_train_00763000', 'ILSVRC2015_test_00004002'] videos = [load_ilsvrc_video(identifier, size=(224, 224), cut_edges=True) for identifier in identifiers] for i in itertools.count(0): with hold_dbplots(): for identifier, vid in zip(identifiers, videos): dbplot(vid[i%len(vid)], identifier, title='%s: %s' % (identifier, i%len(vid)))
def demo_optimize_mnist_net(hidden_sizes=[200, 200], learning_rate=0.01, n_epochs=100, minibatch_size=10,
                            parametrization='log', computation_weights=np.logspace(-6, -3, 8), layerwise_scales=True,
                            show_scales=True, hidden_activations='relu', test_every=0.5, output_activation='softmax',
                            error_loss='L1', comp_evaluation_calc='multiplyadds', smoothing_steps=1000, seed=1234):
    """
    Train an MLP on MNIST, then optimise its quantization scales for several compute weights and collect results.

    The conventionally trained network's unquantized test-set output serves as the reference. For each value in
    ``computation_weights``, the scales are trained with ``CompErrorScaleOptimizer``. Every ``test_every`` epochs
    the current scales are plotted (if ``show_scales``), together with (MFlops, relative error) and
    (MFlops, class error) points on the test set.

    :param hidden_sizes: Hidden layer sizes of the MLP. Forwarded to ``train_conventional_mlp_on_mnist``.
    :param learning_rate: Learning rate for ``GradientDescent`` when training the scales.
    :param n_epochs: Epochs of scale training per compute weight.
    :param minibatch_size: Minibatch size for scale training.
    :param parametrization: Forwarded to ``CompErrorScaleOptimizer``.
    :param computation_weights: Compute-cost weights to optimise for, one training run each.
    :param layerwise_scales: Forwarded to ``CompErrorScaleOptimizer``; also selects how scales are plotted.
    :param show_scales: Whether to plot the current scales at each test point.
    :param hidden_activations: Hidden activation name used throughout.
    :param test_every: Interval, in epochs, between test points.
    :param output_activation: Output activation name used throughout.
    :param error_loss: Forwarded to ``train_scales``.
    :param comp_evaluation_calc: Forwarded as ``computation_calc`` when measuring cost.
    :param smoothing_steps: Forwarded to ``get_mnist_results_with_parameters``.
    :param seed: ``rng`` for MLP training; ``seed + 1`` seeds the scale optimisers.
    :return: OrderedDict mapping 'unoptimized' and 'lambda=<comp_weight>' to the results of
        ``get_mnist_results_with_parameters``.
    """
    train_data, train_targets, test_data, test_targets = get_mnist_dataset(flat=True).to_onehot().xyxy
    params = train_conventional_mlp_on_mnist(hidden_sizes=hidden_sizes, hidden_activations=hidden_activations,
                                             output_activation=output_activation, rng=seed)
    weights, biases = params[::2], params[1::2]  # params alternate weight, bias, weight, bias, ...

    rng = get_rng(seed + 1)

    true_out = forward_pass(input_data=test_data, weights=weights, biases=biases,
                            hidden_activations=hidden_activations, output_activation=output_activation)

    optimized_results = OrderedDict([])
    optimized_results['unoptimized'] = get_mnist_results_with_parameters(
        weights=weights, biases=biases, scales=None, hidden_activations=hidden_activations,
        output_activation=output_activation, smoothing_steps=smoothing_steps)

    set_dbplot_figure_size(15, 10)
    for comp_weight in computation_weights:
        net = CompErrorScaleOptimizer(ws=weights, bs=biases, optimizer=GradientDescent(learning_rate),
                                      comp_weight=comp_weight, layerwise_scales=layerwise_scales,
                                      hidden_activations=hidden_activations, output_activation=output_activation,
                                      parametrization=parametrization, rng=rng)
        f_train = net.train_scales.partial(error_loss=error_loss).compile()
        f_get_scales = net.get_scales.compile()
        for training_minibatch, iter_info in minibatch_iterate_info(train_data, minibatch_size=minibatch_size,
                                                                    n_epochs=n_epochs,
                                                                    test_epochs=np.arange(0, n_epochs, test_every)):
            if iter_info.test_now:  # Test the computation and all that
                ks = f_get_scales()
                print('Epoch %.3g' % (iter_info.epoch, ))
                with hold_dbplots():
                    if show_scales:
                        if layerwise_scales:
                            dbplot(ks, '%s solution_scales' % (comp_weight, ),
                                   plot_type=lambda: LinePlot(plot_kwargs=dict(linewidth=3), make_legend=False,
                                                              axes_update_mode='expand', y_bounds=(0, None)),
                                   axis='solution_scales', xlabel='layer', ylabel='scale')
                        else:
                            # NOTE(review): these names depend only on i, so runs with different comp_weight
                            # values write to the same plots; confirm that is intended.
                            for i, k in enumerate(ks):
                                dbplot(k, '%s solution_scales' % (i, ),
                                       plot_type=lambda: LinePlot(plot_kwargs=dict(linewidth=3), make_legend=False,
                                                                  axes_update_mode='expand', y_bounds=(0, None)),
                                       axis='solution_scales', xlabel='layer', ylabel='scale')
                    current_flop_counts, current_outputs = quantized_forward_pass_cost_and_output(
                        test_data, weights=weights, scales=ks, quantization_method='round',
                        hidden_activations=hidden_activations, output_activation=output_activation,
                        computation_calc=comp_evaluation_calc, seed=1234)
                    current_error = np.abs(current_outputs - true_out).mean() / np.abs(true_out).mean()
                    current_class_error = percent_argmax_incorrect(current_outputs, test_targets)
                    if np.isnan(current_error):
                        print('ERROR IS NAN!!!')
                    dbplot((current_flop_counts / 1e6, current_error), '%s error-curve' % (comp_weight, ),
                           axis='error-curve', plot_type='trajectory+', xlabel='MFlops', ylabel='error')
                    dbplot((current_flop_counts / 1e6, current_class_error), '%s class-curve' % (comp_weight, ),
                           axis='class-curve', plot_type='trajectory+', xlabel='MFlops', ylabel='class-error')
            f_train(training_minibatch)
        # Re-read the scales once training is finished: the last test point can precede the final minibatches,
        # so reusing its `ks` would report results for not-fully-trained scales.
        ks = f_get_scales()
        optimized_results['lambda=%.3g' % (comp_weight, )] = get_mnist_results_with_parameters(
            weights=weights, biases=biases, scales=ks, hidden_activations=hidden_activations,
            output_activation=output_activation, smoothing_steps=smoothing_steps)

    return optimized_results