Example #1
0
def demo_temporal_mnist(n_samples=None, smoothing_steps=200):
    """
    Show MNIST test digits next to the corresponding entries of the temporally-smoothed MNIST dataset.

    :param n_samples: Passed as both n_training_samples and n_test_samples to each dataset loader.
    :param smoothing_steps: Passed to get_temporal_mnist_dataset.
    """
    _, _, test_digits, test_labels = get_mnist_dataset(
        n_training_samples=n_samples, n_test_samples=n_samples).xyxy
    _, _, smoothed_digits, smoothed_labels = get_temporal_mnist_dataset(
        n_training_samples=n_samples, n_test_samples=n_samples,
        smoothing_steps=smoothing_steps).xyxy
    pairs = zip(test_digits, test_labels, smoothed_digits, smoothed_labels)
    for digit, label, smoothed_digit, smoothed_label in pairs:
        # Both images are drawn together in one refresh.
        with hold_dbplots():
            dbplot(digit, 'sample', title=str(label))
            dbplot(smoothed_digit, 'smooth', title=str(smoothed_label))
Example #2
0
def test_periodic_plotting():
    """Plot a sine and a cosine for 100 steps, with hold_dbplots(draw_every='1s') throttling the redraws."""
    for step in range(100):
        phase = step / 10
        with hold_dbplots(draw_every='1s'):
            dbplot(np.sin(phase), 'sinusoid')
            dbplot(np.cos(phase), 'cosinusoid')
        time.sleep(0.02)
def demo_settling_dynamics(symmetric=False,
                           n_hidden=50,
                           n_out=3,
                           input_influence=0.01,
                           learning_rate=0.0001,
                           cut_time=None,
                           minibatch_size=1,
                           decay=0.05,
                           scale=.4,
                           hidden_act='tanh',
                           output_act='lin',
                           draw_every=10,
                           n_steps=10000,
                           seed=124):
    """
    Drive a second network with the output of a first ("driver") network and plot both as they evolve.

    Both networks are built from the same architecture arguments and share one RNG (the driver is built first).
    The second network additionally receives ``input_influence`` and ``learning_rate``.  Each step, the driver
    updates on its own.  The second network is then updated with the driver's new output as ``inp`` while
    ``cut_time`` is None or ``t < cut_time``, and with ``inp=None`` afterwards.

    Plotted each step (before updating):
    - the hidden/output states of both networks ('hd', 'xd', 'hl', 'xl');
    - the mean absolute value of the second network's ``w_hx`` ('wmag');
    - the mean difference between the two networks' outputs ('error').

    :param cut_time: Step at which the driver's output stops being fed in (None: never stop).
    :param draw_every: Passed to hold_dbplots to throttle redrawing.
    :param n_steps: Number of update steps.
    :param seed: Seed passed to get_rng.
    """
    rng = get_rng(seed)
    shared_kwargs = dict(symmetric=symmetric,
                         n_hidden=n_hidden,
                         n_out=n_out,
                         scale=scale,
                         fh=hidden_act,
                         fx=output_act,
                         decay=decay,
                         rng=rng)

    driver = Network.from_init(**shared_kwargs)
    driver_state = driver.init_state(minibatch_size=minibatch_size)

    follower = Network.from_init(input_influence=input_influence,
                                 learning_rate=learning_rate,
                                 **shared_kwargs)
    follower_state = follower.init_state(minibatch_size=minibatch_size)

    speedometer = Speedometer()
    for t in range(n_steps):
        output_gap = (driver_state.x[0] - follower_state.x[0]).mean()
        with hold_dbplots(draw_every=draw_every):
            dbplot(driver_state.h[0], 'hd')
            dbplot(driver_state.x[0], 'xd')
            dbplot(follower_state.h[0], 'hl')
            dbplot(follower_state.x[0], 'xl')
            dbplot(np.array([abs(follower.w_hx).mean()]), 'wmag')
            dbplot(output_gap, 'error')

        driver_state = driver.update(driver_state)
        still_driven = cut_time is None or t < cut_time
        follower_state = follower.update(
            follower_state,
            inp=driver_state.x if still_driven else None)

        if t % 100 == 0:
            print(f'Rate: {speedometer(t+1)} iter/s')
Example #4
0
def run_plotting_server(address, port):
    """
    Run a plotting server that receives pickled dbplot calls from clients over a socket and renders them.

    The server binds to the socket returned by ``get_socket`` (the first available one starting from ``port``). It
    writes the chosen port to a file and prints it so the client that launched the server can find it. Incoming
    plot messages are then rendered in batches in the main thread until a SIGINT is received. The listening
    socket is closed on the way out.

    :param address: Address to listen on.
    :param port: First port to try; the port actually used may differ.
    :return: None.  Returns after a SIGINT has been received.
    """
    # Get the first available socket starting from `port`, and communicate it to the client who started this server.
    sock, port = get_socket(address=address, port=port)
    write_port_to_file(port)
    max_number_clients = 100
    max_plot_batch_size = 20000
    sock.listen(max_number_clients)
    print(port)
    print("Plotting Server is listening")

    # We want to save and rescue the current plot in case the plotting server receives a signal.SIGINT (2)
    killer = GracefulKiller()

    # The plotting server receives input in a queue and returns the plot_ids as a way of communicating that it has
    # rendered the plot.
    main_input_queue = Queue.Queue()
    return_queue = Queue.Queue()

    # Start accepting clients' communication requests in a background thread.
    t0 = threading.Thread(target=handle_socket_accepts,
                          args=(sock, main_input_queue, return_queue,
                                max_number_clients))
    # Thread.setDaemon() is deprecated; assigning the attribute works on Python 2.6+ and 3.x.
    t0.daemon = True
    t0.start()

    # If killed, save the current figure
    atexit.register(save_current_figure)

    try:
        # Now, we can accept plots in the main thread, until a SIGINT has been received.
        while not killer.kill_now:
            # Retrieve data points that might have come in in the mean-time (a list of ClientMessage objects).
            client_messages = _queue_get_all_no_wait(main_input_queue,
                                                     max_plot_batch_size)
            if not client_messages:
                time.sleep(0.1)
                continue
            return_values = []
            with hold_dbplots():
                for client_msg in client_messages:
                    # Take apart the received message, plot, and return the plot_id to the client who sent it.
                    # SECURITY: pickle.loads can execute arbitrary code embedded in a crafted payload; only run
                    # this server where every client able to reach the socket is trusted.
                    plot_message = pickle.loads(
                        client_msg.dbplot_message
                    )  # A DBPlotMessage object (see plotting_client.py)
                    plot_message.dbplot_args['draw_now'] = False
                    dbplot(**plot_message.dbplot_args)
                    return_values.append(
                        (client_msg.client_address, plot_message.plot_id))
            for client, plot_id in return_values:
                return_queue.put([client, plot_id])
    finally:
        # Release the listening socket; this presumably also ends the handle_socket_accepts thread.
        sock.close()
Example #5
0
def test_two_plots_in_the_same_axis_version_2():
    """
    Overlay a histogram of random samples with the normal density fitted to them, on one shared axis.

    Option 2: give both plots the same explicit 'axis' argument ('hist').
    """
    reset_dbplot()
    for _ in xrange(5):
        samples = np.random.randn(200)
        grid = np.linspace(-5, 5, 100)
        mu = np.mean(samples)
        var = np.var(samples)
        density = 1./np.sqrt(2*np.pi*var) * np.exp(-(grid-mu)**2/(2*var))
        with hold_dbplots():
            dbplot(samples, 'histogram', plot_type='histogram', axis='hist')
            dbplot((grid, density), 'density', axis='hist', plot_type='line')
Example #6
0
def test_two_plots_in_the_same_axis_version_1():
    """
    Overlay a histogram of random samples with the normal density fitted to them, on one shared axis.

    Option 1: the density plot's 'axis' argument is set to the name of the first plot ('histogram').
    """
    def fitted_normal_pdf(points, samples):
        # Normal density with the sample mean and variance of `samples`, evaluated at `points`.
        variance = np.var(samples)
        return 1./np.sqrt(2*np.pi*variance) * np.exp(-(points-np.mean(samples))**2/(2*variance))

    reset_dbplot()
    for _ in xrange(5):
        samples = np.random.randn(200)
        grid = np.linspace(-5, 5, 100)
        with hold_dbplots():
            dbplot(samples, 'histogram', plot_type='histogram')
            dbplot((grid, fitted_normal_pdf(grid, samples)), 'density', axis='histogram', plot_type='line')
def test_inline_custom_plots():
    """
    Mix a regular dbplot line plot with lines drawn directly through matplotlib.

    Each frame calls use_dbplot_axis('custom', clear=True) (presumably selecting and clearing a dbplot-managed axis
    named 'custom') before drawing the wave and its square with plt.plot.
    """
    offsets = np.linspace(0, 10, 200)
    for frame in range(10):
        with hold_dbplots():
            wave = np.sin(frame / 10. + offsets)
            dbplot(wave, 'x', plot_type='line')
            use_dbplot_axis('custom', clear=True)
            for series, label in ((wave, 'x'), (wave**2, '$x**2$')):
                plt.plot(series, label=label, linewidth=2)
def test_bbox_display():
    """
    Regression check: draw a bounding box onto an image's axis inside the same hold_dbplots block as the image.

    (Bboxes once failed in this situation.)
    """
    def random_uint8_image():
        # 40x40 image with uniformly random values in 0..255.
        return (np.random.rand(40, 40) * 255.999).astype(np.uint8)

    with hold_dbplots():
        dbplot(random_uint8_image(), 'gfdsg')
        dbplot(random_uint8_image(), 'img')
        dbplot([10, 20, 25, 30], 'bbox', axis='img', plot_type=DBPlotTypes.BBOX)
Example #9
0
def demo_dbplot(n_frames=1000):
    """
    Demonstrates the various types of plots.

    The appropriate plot can usually be inferred from the first input data.  In cases where there are multiple ways
    to display the input data, you can use the plot_type argument.

    :param n_frames: Number of frames to draw.  On every 10th frame, the rate of that single frame is printed.
    """
    # Approximate frame rates:
    # Macbook Air, MacOSX backend, mode=safe ~2.4 FPS
    # Macbook Air, MacOSX backend, mode=fast:  FAILS
    # Macbook Air, TkAgg backend, mode=safe: ~1.48 FPS
    # Macbook Air, TkAgg backend, mode=fast: ~4.4 FPS
    # Linux box, Qt4Agg backend, mode=safe: ~2.5 FPS
    # Linux box, Qt4Agg backend, mode=fast: (plot does not update)

    # set_dbplot_figure_size(150, 100)
    for i in xrange(n_frames):
        t_start = time.time()
        # Sets it so that all plots update at once (rather than redrawing on each call, which is slower)
        with hold_dbplots():
            dbplot(np.random.randn(20, 20), 'Greyscale Image')
            dbplot(np.random.randn(20, 20, 3), 'Colour Image')
            dbplot(np.random.randn(15, 20, 20), "Many Images")
            dbplot(np.random.randn(3, 6, 20, 20, 3), "Colour Image Grid")
            dbplot([
                np.random.randn(15, 20, 3),
                np.random.randn(10, 10, 3),
                np.random.randn(10, 30, 3)
            ], "Differently Sized images")
            dbplot(np.random.randn(20, 2), 'Two Lines')
            dbplot(
                (np.linspace(-5, 5, 100) + np.sin(np.linspace(-5, 5, 100) * 2),
                 np.linspace(-5, 5, 100)), '(X,Y) Lines')
            dbplot([np.sin(i / 20.), np.sin(i / 15.)], 'Moving Point History')
            dbplot([np.sin(i / 20.), np.sin(i / 15.)],
                   'Bounded memory history',
                   plot_type=partial(MovingPointPlot, buffer_len=10))
            dbplot((np.cos(i / 10.), np.sin(i / 11.)),
                   '(X,Y) Moving Point History')
            # Adding the (3,) offset to the (x, y) tuple broadcasts to a (2, 3) array: three points,
            # each shifted by 0, 1 or 2 in both x and y.
            dbplot(
                (np.cos(i / np.array([10., 15., 20.])),
                 np.sin(i / np.array([11., 16., 21.]))) + np.array([0, 1, 2]),
                'multi (X,Y) Moving Point History',
                plot_type='trajectory')
            dbplot(np.random.randn(20), 'Vector History')
            dbplot(np.random.randn(50), 'Histogram', plot_type='histogram')
            dbplot(np.random.randn(50),
                   'Cumulative Histogram',
                   plot_type='cumhist')
            dbplot(('Veni', 'Vidi', 'Vici')[i % 3], 'text-history')
            dbplot(('Veni', 'Vidi', 'Vici')[i % 3],
                   'text-notice',
                   plot_type='notice')
        if i % 10 == 0:
            # '.3g' = 3 significant figures ('3g' would only set a minimum field width of 3).
            print('Frame Rate: {:.3g}FPS'.format(1. / (time.time() - t_start)))
Example #10
0
def demo_dbplot():
    """
    Demonstrates the various types of plots, printing the mean frame rate every 10 frames (1000 frames in total).

    The appropriate plot can usually be inferred from the first input data.  In cases where there are multiple ways
    to display the input data, you can use the plot_type argument.
    """
    from matplotlib import pyplot as plt
    plt.ion()

    set_dbplot_figure_size(15, 10)
    with EZProfiler('plot time') as prof:
        for i in xrange(1000):
            spiral_angles = np.array([i, i * 1.5, i * 2]) / 5.
            # Hold the plots so they all update together rather than redrawing after every call (which is slower).
            with hold_dbplots():
                dbplot(np.random.randn(20, 20), 'Greyscale Image')
                dbplot(np.random.randn(20, 20, 3), 'Colour Image')
                dbplot(np.random.randn(15, 20, 20), "Many Images")
                dbplot(np.random.randn(3, 6, 20, 20, 3), "Colour Image Grid")
                differently_sized = [
                    np.random.randn(15, 20, 3),
                    np.random.randn(10, 10, 3),
                    np.random.randn(10, 30, 3),
                ]
                dbplot(differently_sized, "Differently Sized images")
                dbplot(np.random.randn(20, 2), 'Two Lines')
                line_x = np.linspace(-5, 5, 100)
                dbplot((line_x + np.sin(line_x * 2), np.linspace(-5, 5, 100)), '(X,Y) Lines')
                moving_point = [np.sin(i / 20.), np.sin(i / 15.)]
                dbplot(moving_point, 'Moving Point History')
                dbplot([np.sin(i / 20.), np.sin(i / 15.)], 'Bounded memory history',
                       plot_type=lambda: MovingPointPlot(buffer_len=10))
                dbplot((np.sin(i / 5.) * (i + 100.), np.cos(i / 5.) * (i + 100.)),
                       '(X,Y) Moving Point History')
                dbplot((np.sin(spiral_angles) * (i + 10.), np.cos(spiral_angles) * (i + 10.)),
                       'multi (X,Y) Moving Point History', plot_type='trajectory')
                dbplot(np.random.randn(20), 'Vector History')
                dbplot(np.random.randn(50), 'Histogram', plot_type='histogram')
                dbplot(np.random.randn(50), 'Cumulative Histogram', plot_type='cumhist')
                motto = ('Veni', 'Vidi', 'Vici')[i % 3]
                dbplot(motto, 'text-history')
                dbplot(('Veni', 'Vidi', 'Vici')[i % 3], 'text-notice', plot_type='notice')

            if i % 10 == 0:
                # A single parenthesised expression prints identically under Python 2's print statement.
                print('Mean Frame Rate: %.3gFPS' % ((i + 1) / prof.get_current_time(), ))
Example #11
0
def demo_temporal_mnist(n_samples=None, smoothing_steps=200):
    """
    Step through MNIST test digits, plotting each raw digit beside the matching temporally-smoothed digit.

    :param n_samples: Used as both n_training_samples and n_test_samples for each dataset.
    :param smoothing_steps: Forwarded to get_temporal_mnist_dataset.
    """
    sizes = dict(n_training_samples=n_samples, n_test_samples=n_samples)
    _, _, raw_x, raw_y = get_mnist_dataset(**sizes).xyxy
    _, _, smooth_x, smooth_y = get_temporal_mnist_dataset(
        smoothing_steps=smoothing_steps, **sizes).xyxy
    for frame in zip(raw_x, raw_y, smooth_x, smooth_y):
        raw_img, raw_label, smooth_img, smooth_label = frame
        with hold_dbplots():
            dbplot(raw_img, 'sample', title=str(raw_label))
            dbplot(smooth_img, 'smooth', title=str(smooth_label))
Example #12
0
def transaction_chooser(futures_stream,
                        transaction_cost,
                        initial_have_state=False):
    """
    Walk through a stream of forecast samples, deciding when to buy and sell and plotting the forecasts.

    Each item of ``futures_stream`` is a pair ``(x0, futures)``:
    - ``x0`` is the current price.
    - ``futures`` is an (n_samples, n_steps) array of sampled futures, with futures[:, 0] holding n_samples copies
      of the current time-step.

    Decisions use the mean forecast across samples:
    - Not holding: find the first step whose mean forecast exceeds ``x0 + transaction_cost``. Buy now if there is
      such a step and the mean forecast never drops below ``x0`` before it.
    - Holding: find the first step whose mean forecast is below ``x0 - transaction_cost``. Sell now if there is
      such a step and the mean forecast never rises above ``x0`` before it.

    Each trade is printed and marked on the 'futures' plot.

    :param futures_stream: Iterable of (x0, futures) pairs as described above.
    :param transaction_cost: Margin the mean forecast must move by before a trade is considered.
    :param initial_have_state: Whether we start out holding the asset.
    :return: None
    """
    holding = initial_have_state
    price_history = []
    for t, (x0, futures) in enumerate(futures_stream):
        # Average over samples: one value per future time-step.
        mean_future = np.mean(futures, axis=0)
        price_history.append(x0)

        with hold_dbplots(draw_every='0.05s'):
            # Every sampled future is drawn prefixed by the observed price history.
            past = [np.array(price_history)] * len(futures)
            dbplot(np.concatenate([past, futures], axis=1).T,
                   'futures',
                   plot_type=('line', dict(color='C0', axes_update_mode='expand')))

        if holding:
            floor = x0 - transaction_cost
            t_buy = next((tau for tau, m in enumerate(mean_future) if m < floor), None)
            if t_buy is not None and (t_buy == 0 or not mean_future[:t_buy].max() > x0):
                print(f'Selling at t={t}, for ${x0:.3g}')
                mark_trade(trade=Trade(time=t, price=x0, type=TradeTypes.SELL),
                           ax=use_dbplot_axis('futures'))
                holding = False
        else:
            ceiling = x0 + transaction_cost
            # First step meeting the criterion, or None if it is never met.
            t_sell = next((tau for tau, m in enumerate(mean_future) if m > ceiling), None)
            if t_sell is not None and (t_sell == 0 or mean_future[:t_sell].min() >= x0):
                print(f'Buying at t={t}, for ${x0:.3g}')
                mark_trade(trade=Trade(time=t, price=x0, type=TradeTypes.BUY),
                           ax=use_dbplot_axis('futures'))
                holding = True
def test_moving_point_multiple_points():
    """
    Track the two-point signal [sin(i/5), sin(i/8)] with three history plots:
    - an unbounded MovingPointPlot;
    - a MovingPointPlot limited to 20 points;
    - a ResamplingLineHistory with a 20-point buffer.
    """
    reset_dbplot()
    divisors = (5., 8.)
    for i in xrange(50):
        phases = [i / d for d in divisors]
        with hold_dbplots(draw_every=5):
            plot_specs = [
                ('unlim buffer', partial(MovingPointPlot)),
                ('lim buffer', partial(MovingPointPlot, buffer_len=20)),
                # The resampling plot only looks bad because the buffer is kept very small for testing.
                ('resampling buffer', partial(ResamplingLineHistory, buffer_len=20)),
            ]
            for plot_name, plot_type in plot_specs:
                dbplot(np.sin(phases), plot_name, plot_type=plot_type)
Example #14
0
def transaction_chooser(futures_stream,
                        transaction_cost,
                        riskiness=0.1,
                        initial_have_state=False):
    """
    Decide when to buy and sell from a stream of sampled price forecasts, plotting as it goes.

    Each item of ``futures_stream`` is averaged over axis 0 (the samples) to give a mean forecast per step. Its
    first entry is taken as the current price ``x0``.

    - Not holding: look for the first step ``tau`` whose mean forecast exceeds ``x0 + transaction_cost`` and for
      which ``riskiness > -expected_shortfall(x0, futures[:, tau])``. Buy now if such a step exists and the mean
      forecast stays strictly above ``x0`` at every step between step 1 and ``tau``.
    - Holding: look for the first step whose mean forecast is below ``x0 - transaction_cost``. Sell now if such a
      step exists and the mean forecast never rises above ``x0`` between step 1 and that step.

    :param futures_stream: Iterable of sampled-futures arrays (samples along axis 0).
    :param transaction_cost: Margin the mean forecast must move by before a trade is considered.
    :param riskiness: Threshold compared against the negated expected shortfall when looking for a sell step.
    :param initial_have_state: Whether we start out holding the asset.
    :return: None
    """
    holding = initial_have_state

    for t, futures in enumerate(futures_stream):
        # Average over samples: one value per future step.
        mean_future = np.mean(futures, axis=0)
        x0 = mean_future[0]

        with hold_dbplots():
            dbplot(x0, 'x0')
            dbplot(futures.T, 'futures', plot_type='line')

        if holding:
            # Is there a moment when buying a new share is cheaper than keeping this one?
            cheaper_later = (tau for tau, m in enumerate(mean_future)
                             if m < x0 - transaction_cost)
            t_buy = get_next_or_none(cheaper_later)
            # ...and until then, is there no better moment to sell?
            if t_buy is not None and (t_buy == 1 or not mean_future[1:t_buy].max() > x0):
                print('SELL SELL SELL')
                use_dbplot_axis('x0')
                mark_trade(trade_time=t, trade_price=x0, trade_type='sell')
                holding = False
        else:
            # When (if ever) is there a profitable, acceptably risky sell moment?
            profitable_sells = (
                tau for tau, m in enumerate(mean_future)
                if m > x0 + transaction_cost
                and riskiness > -1 * expected_shortfall(x0, futures[:, tau]))
            t_sell = get_next_or_none(profitable_sells)
            # ...and until then, is there no better moment to buy?
            if t_sell is not None and (t_sell == 1 or mean_future[1:t_sell].min() > x0):
                print('BUY BUY BUY')
                use_dbplot_axis('x0')
                mark_trade(trade_time=t, trade_price=x0, trade_type='buy')
                holding = True
def demo_settling_dynamics(symmetric=False,
                           n_hidden=50,
                           n_out=3,
                           minibatch_size=1,
                           decay=0.05,
                           scale=.4,
                           hidden_act='tanh',
                           output_act='lin',
                           draw_every=10,
                           n_steps=10000,
                           seed=124):
    """
    Let a single randomly-initialised Network run for ``n_steps`` updates, plotting its hidden and output states.

    :param hidden_act: Passed to Network.from_init as ``fh``.
    :param output_act: Passed to Network.from_init as ``fx``.
    :param draw_every: Passed to hold_dbplots to throttle redrawing.
    :param n_steps: Number of update steps.
    :param seed: Passed to Network.from_init as ``rng``.
    """
    net = Network.from_init(symmetric=symmetric, n_hidden=n_hidden, n_out=n_out, scale=scale,
                            fh=hidden_act, fx=output_act, decay=decay, rng=seed)
    state = net.init_state(minibatch_size=minibatch_size)
    speedometer = Speedometer()

    for t in range(n_steps):
        state = net.update(state)

        if t % 100 == 0:
            print(f'Rate: {speedometer(t+1)} iter/s')

        with hold_dbplots(draw_every=draw_every):
            hidden_title = 'Hidden Units (b={})'.format(net.b_h)
            dbplot(state.h[0], 'Hidden Units', title=hidden_title)
            # NOTE(review): the 'Y Units' title also reports net.b_h (presumably the hidden-layer bias);
            # an output-layer bias may have been intended -- confirm.
            output_title = 'Y Units (b={})'.format(net.b_h)
            dbplot(state.x[0], 'Y Units', title=output_title)
Example #16
0
def demo_dbplot(n_frames=1000):
    """
    Demonstrates the various types of plots.

    The appropriate plot can usually be inferred from the first input data.  In cases where there are multiple ways
    to display the input data, you can use the plot_type argument.

    :param n_frames: Number of frames to draw.  On every 10th frame, the rate of that single frame is printed.
    """
    # Approximate frame rates:
    # Macbook Air, MacOSX backend, mode=safe ~2.4 FPS
    # Macbook Air, MacOSX backend, mode=fast:  FAILS
    # Macbook Air, TkAgg backend, mode=safe: ~1.48 FPS
    # Macbook Air, TkAgg backend, mode=fast: ~4.4 FPS
    # Linux box, Qt4Agg backend, mode=safe: ~2.5 FPS
    # Linux box, Qt4Agg backend, mode=fast: (plot does not update)

    # set_dbplot_figure_size(150, 100)
    for i in xrange(n_frames):
        t_start = time.time()
        with hold_dbplots():  # Sets it so that all plots update at once (rather than redrawing on each call, which is slower)
            dbplot(np.random.randn(20, 20), 'Greyscale Image')
            dbplot(np.random.randn(20, 20, 3), 'Colour Image')
            dbplot(np.random.randn(15, 20, 20), "Many Images")
            dbplot(np.random.randn(3, 6, 20, 20, 3), "Colour Image Grid")
            dbplot([np.random.randn(15, 20, 3), np.random.randn(10, 10, 3), np.random.randn(10, 30, 3)], "Differently Sized images")
            dbplot(np.random.randn(20, 2), 'Two Lines')
            dbplot((np.linspace(-5, 5, 100)+np.sin(np.linspace(-5, 5, 100)*2), np.linspace(-5, 5, 100)), '(X,Y) Lines')
            dbplot([np.sin(i/20.), np.sin(i/15.)], 'Moving Point History')
            dbplot([np.sin(i/20.), np.sin(i/15.)], 'Bounded memory history', plot_type=partial(MovingPointPlot, buffer_len=10))
            dbplot((np.cos(i/10.), np.sin(i/11.)), '(X,Y) Moving Point History')
            # The (x, y) tuple plus a (3,) offset broadcasts to a (2, 3) array: three points shifted by 0, 1, 2.
            dbplot((np.cos(i/np.array([10., 15., 20.])), np.sin(i/np.array([11., 16., 21.])))+np.array([0, 1, 2]), 'multi (X,Y) Moving Point History', plot_type='trajectory')
            dbplot(np.random.randn(20), 'Vector History')
            dbplot(np.random.randn(50), 'Histogram', plot_type='histogram')
            dbplot(np.random.randn(50), 'Cumulative Histogram', plot_type='cumhist')
            dbplot(('Veni', 'Vidi', 'Vici')[i%3], 'text-history')
            dbplot(('Veni', 'Vidi', 'Vici')[i%3], 'text-notice', plot_type='notice')
        if i % 10 == 0:
            # '.3g' = 3 significant figures ('3g' would only set a minimum field width of 3).
            print('Frame Rate: {:.3g}FPS'.format(1./(time.time() - t_start)))
Example #17
0
def demo_converge_to_pareto_curve(layer_sizes=None,
                                  w_scales=None,
                                  n_samples=100,
                                  learning_rate=0.01,
                                  n_epochs=100,
                                  minibatch_size=10,
                                  n_random_points_to_try=1000,
                                  random_scale_range=(1, 5),
                                  parametrization='log',
                                  computation_weights=np.logspace(-6, -3, 8),
                                  layerwise_scales=True,
                                  show_random_scales=True,
                                  error_loss='L1',
                                  hang_now=True,
                                  seed=1234):
    """
    Compare random quantization scales against scales optimized for a computation/error tradeoff.

    Steps:
    1. Build a random network and take its unquantized output on random data as ground truth.
    2. Evaluate many random per-layer scale vectors and plot their (kOps/sample, error) as a grey cloud on the
       'Tradeoff' axis.
    3. For each value in ``computation_weights``, train a CompErrorScaleOptimizer. After every test epoch, plot its
       current scales and the resulting (kOps/sample, relative mean-absolute error) point.

    :param layer_sizes: Layer widths passed to initialize_network_params; layer_sizes[0] is the input dimension of
        the random data.  Defaults to [100, 100, 100, 100].
    :param w_scales: Multipliers applied (pairwise, via izip_equal) to the initialized weight matrices.
        Defaults to [1, 1, 1].
    :param n_samples: Number of random training rows.
    :param learning_rate: Learning rate of the GradientDescent optimizer.
    :param n_epochs: Number of training epochs per computation weight (evaluated every epoch).
    :param minibatch_size: Minibatch size for scale training.
    :param n_random_points_to_try: Number of random scale vectors to evaluate.
    :param random_scale_range: Random scales are |Normal(mean(range), diff(range))| samples.
    :param parametrization: Forwarded to CompErrorScaleOptimizer.
    :param computation_weights: comp_weight values; one optimizer is trained per value.
    :param layerwise_scales: Forwarded to CompErrorScaleOptimizer.
    :param show_random_scales: Whether to plot the scale vectors (random and optimized) on the 'Scales' axis.
    :param error_loss: Forwarded to train_scales.
    :param hang_now: If True, call dbplot_hang() once finished.
    :param seed: Seeds the RandomState used for weights, data, random scales and the optimizer.  (The quantization
        calls below use a fixed seed of 1234.)
    """
    # Defaults are filled in here rather than in the signature to avoid shared mutable default lists.
    if layer_sizes is None:
        layer_sizes = [100, 100, 100, 100]
    if w_scales is None:
        w_scales = [1, 1, 1]

    set_dbplot_default_layout('h')

    rng = np.random.RandomState(seed)
    ws = initialize_network_params(layer_sizes=layer_sizes,
                                   mag='xavier-relu',
                                   include_biases=False,
                                   rng=rng)
    ws = [w * s for w, s in izip_equal(ws, w_scales)]
    train_data = rng.randn(n_samples, layer_sizes[0])
    _, true_out = quantized_forward_pass_cost_and_output(
        train_data,
        weights=ws,
        scales=None,
        quantization_method=None,
        seed=1234)

    # Run the random search
    scales_to_try = np.abs(
        rng.normal(loc=np.mean(random_scale_range),
                   scale=np.diff(random_scale_range),
                   size=(n_random_points_to_try, len(ws))))

    if show_random_scales:
        ax = dbplot(
            scales_to_try.T,
            'random_scales',
            axis='Scales',
            plot_type=lambda: LinePlot(plot_kwargs=dict(color=(.6, .6, .6)),
                                       make_legend=False),
            xlabel='Layer',
            ylabel='Scale')
        # Only valid when the scales plot (and hence `ax`) exists; otherwise `ax` is undefined.
        ax.set_xticks(np.arange(len(w_scales)))

    random_flop_counts, random_errors = compute_flop_errors_for_scales(
        train_data,
        scales_to_try,
        ws=ws,
        quantization_method='round',
        true_out=true_out,
        seed=1234)
    dbplot((random_flop_counts / 1e3 / len(train_data), random_errors),
           'random_flop_errors',
           axis='Tradeoff',
           xlabel='kOps/sample',
           ylabel='Error',
           plot_type=lambda: LinePlot(plot_kwargs=dict(
               color=(.6, .6, .6), marker='.', linestyle=' ')))

    # Now run with optimization, across several values of K (total scale)
    for comp_weight in computation_weights:
        net = CompErrorScaleOptimizer(ws,
                                      optimizer=GradientDescent(learning_rate),
                                      comp_weight=comp_weight,
                                      layerwise_scales=layerwise_scales,
                                      hidden_activations='relu',
                                      output_activation='relu',
                                      parametrization=parametrization,
                                      rng=rng)
        f_train = net.train_scales.partial(error_loss=error_loss).compile()
        f_get_scales = net.get_scales.compile()
        for training_minibatch, iter_info in minibatch_iterate_info(
                train_data,
                minibatch_size=minibatch_size,
                n_epochs=n_epochs,
                test_epochs=np.arange(0, n_epochs, 1)):
            if iter_info.test_now:
                ks = f_get_scales()
                with hold_dbplots():
                    if show_random_scales:
                        dbplot(ks,
                               'solution_scales ' + str(comp_weight),
                               axis='Scales',
                               plot_type=lambda: LinePlot(
                                   plot_kwargs=dict(linewidth=3),
                                   make_legend=False,
                                   axes_update_mode='expand'))
                    current_flop_counts, current_outputs = quantized_forward_pass_cost_and_output(
                        train_data,
                        weights=ws,
                        scales=ks,
                        quantization_method='round',
                        seed=1234)
                    # Mean absolute error, relative to the mean absolute magnitude of the true output.
                    current_error = np.abs(current_outputs - true_out).mean(
                    ) / np.abs(true_out).mean()
                    if np.isnan(current_error):
                        # Parenthesised form prints identically under Python 2 and is valid Python 3.
                        print('ERROR IS NAN!!!')
                    dbplot((current_flop_counts / 1e3 / len(train_data),
                            current_error),
                           'k=%.3g curve' % (comp_weight, ),
                           axis='Tradeoff',
                           plot_type=lambda: Moving2DPointPlot(
                               legend_entries='$\\lambda=%.3g$' % comp_weight,
                               axes_update_mode='expand',
                               legend_entry_size=11))
            f_train(training_minibatch)

    if hang_now:
        dbplot_hang()
Example #18
0
def run_plotting_server(address, port, client_address, client_port):
    """
    Run a plotting server that renders pickled dbplot calls received from clients over a socket.

    Startup:
    - Bind to the socket returned by ``get_socket`` (the first available one starting from ``port``).
    - Record the chosen port with write_port_to_file and send it to (client_address, client_port).

    Main loop: render incoming messages in batches in the main thread until a SIGINT is received, then close the
    listening socket.

    Exit hook:
    - If the first non-empty batch contains a message named "exp_dir", that message is not plotted. Its ``data``
      is passed to save_current_figure, which is registered to run at exit.
    - Otherwise save_current_figure is registered with no arguments.

    :param address: Address to listen on.
    :param port: First port to try; the port actually used may differ.
    :param client_address: Address of the client that started this server; the chosen port is sent there.
    :param client_port: Port on ``client_address`` that the chosen port is sent to.
    :return: None.  Returns after a SIGINT has been received.
    """
    # Get the first available socket starting from `port`, and communicate it to the client who started this server.
    sock, port = get_socket(address=address, port=port)
    write_port_to_file(port)
    max_number_clients = 100
    max_plot_batch_size = 2000
    sock.listen(max_number_clients)
    one_time_send_to(address=client_address, port=client_port, message=str(port))

    # We want to save and rescue the current plot in case the plotting server receives a signal.SIGINT (2)
    killer = GracefulKiller()

    # The plotting server receives input in a queue and returns the plot_ids as a way of communicating that it has
    # rendered the plot.
    main_input_queue = Queue.Queue()
    return_queue = Queue.Queue()
    # Start accepting clients' communication requests in a background thread.
    t0 = threading.Thread(target=handle_socket_accepts, args=(sock, main_input_queue, return_queue, max_number_clients))
    # Thread.setDaemon() is deprecated; assigning the attribute works on Python 2.6+ and 3.x.
    t0.daemon = True
    t0.start()

    # Whether the exit-time save_current_figure hook has been registered yet.
    exp_dir_received = False

    set_dbplot_figure_size(9, 10)
    try:
        # Now, we can accept plots in the main thread, until a SIGINT has been received.
        while not killer.kill_now:
            # Retrieve data points that might have come in in the mean-time (a list of ClientMessage objects).
            client_messages = _queue_get_all_no_wait(main_input_queue, max_plot_batch_size)
            if not client_messages:
                time.sleep(0.1)
                continue
            return_values = []
            with hold_dbplots():
                for client_msg in client_messages:
                    # Take apart the received message, plot, and return the plot_id to the client who sent it.
                    # SECURITY: pickle.loads can execute arbitrary code embedded in a crafted payload; only run
                    # this server where every client able to reach the socket is trusted.
                    plot_message = pickle.loads(client_msg.dbplot_message)  # A DBPlotMessage object (see plotting_client.py)
                    plot_message.dbplot_args['draw_now'] = False

                    if not exp_dir_received and "exp_dir" == plot_message.dbplot_args["name"]:
                        # Not a real plot: it carries the argument for the exit-time figure save.
                        # NOTE(review): no (client, plot_id) acknowledgement is queued for this message; confirm the
                        # client does not wait for one.
                        atexit.register(save_current_figure, plot_message.dbplot_args["data"])
                        exp_dir_received = True
                        continue
                    axis = dbplot(**plot_message.dbplot_args)
                    axis.ticklabel_format(style='sci', useOffset=False)
                    return_values.append((client_msg.client_address, plot_message.plot_id))
                if not exp_dir_received:
                    atexit.register(save_current_figure)
                    exp_dir_received = True
                plt.rcParams.update({'axes.titlesize': 'small', 'axes.labelsize': 'small'})
                plt.subplots_adjust(hspace=0.4, wspace=0.6)
            for client, plot_id in return_values:
                return_queue.put([client, plot_id])
    finally:
        # Closing the socket also causes the handle_socket_accepts thread to terminate.
        sock.close()
Example #19
0
if __name__ == '__main__':
    from src.peters_stuff.sample_data import SampleImages
    from src.peters_stuff.image_crop_generator import iter_bbox_batches

    from artemis.plotting.db_plotting import dbplot, hold_dbplots, DBPlotTypes

    # Demo:
    # - sample batches of crop boxes over an image and show the raw crops;
    # - overlay each box on the image;
    # - scatter-plot the crop positions.
    img = SampleImages.sistine_512()
    normscale = 0.25

    dbplot(img, 'image')

    bbox_batches = iter_bbox_batches(image_shape=img.shape[:2],
                                     crop_size=(64, 64),
                                     batch_size=64,
                                     position_generator_constructor='normal',
                                     n_iter=None,
                                     normscale=normscale)
    for bboxes in bbox_batches:
        raw_image_crops, normed_image_crops, positions = get_normed_crops_and_position_tensors(
            img=img, bboxes=bboxes, scale=1. / normscale)

        with hold_dbplots():
            dbplot(raw_image_crops, 'crops')
            for box_index, bbox in enumerate(bboxes):
                dbplot(bbox, f'bbox[{box_index}]', axis='image', plot_type=DBPlotTypes.BBOX_R)
            # NOTE(review): unlike every other dbplot call here, no plot name is passed; confirm dbplot's default
            # naming gives this scatter a stable identity across iterations.
            dbplot((positions[:, 0].numpy(), positions[:, 1].numpy()),
                   plot_type=DBPlotTypes.SCATTER)
Example #20
0
    subpath = \
        'ILSVRC2015/Data/VID/snippets/test' if 'test' in identifier else \
        'ILSVRC2015/Data/VID/snippets/val' if 'val' in identifier else \
        'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0001/' if os.path.exists(os.path.join(archive_folder_path, 'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0001/', identifier + '.mp4')) else \
        'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0002/' if os.path.exists(os.path.join(archive_folder_path, 'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0002/', identifier + '.mp4')) else \
        'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0003/' if os.path.exists(os.path.join(archive_folder_path, 'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0003/', identifier + '.mp4')) else \
        bad_value(identifier, 'Could not find identifier: {}'.format(identifier, ))

    print('Loading %s' % (identifier, ))
    full_path = get_file_in_archive(
        relative_path='data/ILSVRC2015',
        subpath=os.path.join(subpath, identifier+'.mp4'),
        url='http://vision.cs.unc.edu/ilsvrc2015/ILSVRC2015_VID_snippets_final.tar.gz'
        )
    video = smart_load_video(full_path, size=size, cut_edges=cut_edges, resize_mode=resize_mode, cut_edges_thresh=cut_edges_thresh)
    print('Done.')
    return video


if __name__ == '__main__':
    import itertools
    from artemis.plotting.db_plotting import dbplot, hold_dbplots

    identifiers = ['ILSVRC2015_train_00033009', 'ILSVRC2015_train_00033010', 'ILSVRC2015_train_00763000', 'ILSVRC2015_test_00004002']
    videos = [load_ilsvrc_video(identifier, size=(224, 224), cut_edges=True) for identifier in identifiers]

    # Loop forever.  Each redraw shows frame `step` of every video, wrapping around each video's own length.
    for step in itertools.count(0):
        with hold_dbplots():
            for name, video in zip(identifiers, videos):
                frame_index = step % len(video)
                dbplot(video[frame_index], name, title='%s: %s' % (name, frame_index))
def demo_optimize_mnist_net(hidden_sizes=None,
                            learning_rate=0.01,
                            n_epochs=100,
                            minibatch_size=10,
                            parametrization='log',
                            computation_weights=np.logspace(-6, -3, 8),
                            layerwise_scales=True,
                            show_scales=True,
                            hidden_activations='relu',
                            test_every=0.5,
                            output_activation='softmax',
                            error_loss='L1',
                            comp_evaluation_calc='multiplyadds',
                            smoothing_steps=1000,
                            seed=1234):
    """
    Train a conventional MLP on MNIST, then, for each computation weight, learn
    scales with CompErrorScaleOptimizer while plotting the cost/error tradeoff.

    :param hidden_sizes: Hidden layer sizes of the MLP. None means [200, 200].
    :param learning_rate: Learning rate of the GradientDescent optimizer used on the scales.
    :param n_epochs: Number of epochs of scale training per computation weight.
    :param minibatch_size: Minibatch size for scale training.
    :param parametrization: Scale parametrization passed to CompErrorScaleOptimizer.
    :param computation_weights: Iterable of computation-cost weights (lambda); one run per value.
    :param layerwise_scales: Passed to CompErrorScaleOptimizer; also selects how scales are plotted.
    :param show_scales: If True, plot the current scales at every test point.
    :param hidden_activations: Hidden-layer activation name.
    :param test_every: Spacing (in epochs) between test/plot points.
    :param output_activation: Output activation name.
    :param error_loss: Loss passed to the scale-training function.
    :param comp_evaluation_calc: How computation is counted in the quantized forward pass.
    :param smoothing_steps: Passed to get_mnist_results_with_parameters.
    :param seed: Seed for training the MLP; seed + 1 seeds the scale optimizers.
    :return: OrderedDict mapping 'unoptimized' and 'lambda=<comp_weight>' to the
        result of get_mnist_results_with_parameters for the corresponding scales.
    """
    if hidden_sizes is None:
        # A literal list default would be shared between calls; build a fresh one instead.
        hidden_sizes = [200, 200]

    train_data, train_targets, test_data, test_targets = get_mnist_dataset(
        flat=True).to_onehot().xyxy

    params = train_conventional_mlp_on_mnist(
        hidden_sizes=hidden_sizes,
        hidden_activations=hidden_activations,
        output_activation=output_activation,
        rng=seed)
    # params alternates weight, bias, weight, bias, ...
    weights, biases = params[::2], params[1::2]

    rng = get_rng(seed + 1)

    # Reference (unquantized) outputs against which the relative error is measured.
    true_out = forward_pass(input_data=test_data,
                            weights=weights,
                            biases=biases,
                            hidden_activations=hidden_activations,
                            output_activation=output_activation)
    optimized_results = OrderedDict([])
    optimized_results['unoptimized'] = get_mnist_results_with_parameters(
        weights=weights,
        biases=biases,
        scales=None,
        hidden_activations=hidden_activations,
        output_activation=output_activation,
        smoothing_steps=smoothing_steps)

    def make_scale_plot():
        # Factory handed to dbplot for the scale plots (identical in both branches below).
        return LinePlot(
            plot_kwargs=dict(linewidth=3),
            make_legend=False,
            axes_update_mode='expand',
            y_bounds=(0, None))

    set_dbplot_figure_size(15, 10)
    for comp_weight in computation_weights:
        net = CompErrorScaleOptimizer(ws=weights,
                                      bs=biases,
                                      optimizer=GradientDescent(learning_rate),
                                      comp_weight=comp_weight,
                                      layerwise_scales=layerwise_scales,
                                      hidden_activations=hidden_activations,
                                      output_activation=output_activation,
                                      parametrization=parametrization,
                                      rng=rng)
        f_train = net.train_scales.partial(error_loss=error_loss).compile()
        f_get_scales = net.get_scales.compile()
        for training_minibatch, iter_info in minibatch_iterate_info(
                train_data,
                minibatch_size=minibatch_size,
                n_epochs=n_epochs,
                test_epochs=np.arange(0, n_epochs, test_every)):
            if iter_info.test_now:  # Test the computation and all that
                ks = f_get_scales()
                print('Epoch %.3g' % (iter_info.epoch, ))
                with hold_dbplots():
                    if show_scales:
                        if layerwise_scales:
                            dbplot(ks,
                                   '%s solution_scales' % (comp_weight, ),
                                   plot_type=make_scale_plot,
                                   axis='solution_scales',
                                   xlabel='layer',
                                   ylabel='scale')
                        else:
                            for i, k in enumerate(ks):
                                dbplot(k,
                                       '%s solution_scales' % (i, ),
                                       plot_type=make_scale_plot,
                                       axis='solution_scales',
                                       xlabel='layer',
                                       ylabel='scale')
                    # NOTE(review): seed is hard-coded here rather than taken from the
                    # `seed` argument -- confirm whether that is intentional.
                    current_flop_counts, current_outputs = quantized_forward_pass_cost_and_output(
                        test_data,
                        weights=weights,
                        scales=ks,
                        quantization_method='round',
                        hidden_activations=hidden_activations,
                        output_activation=output_activation,
                        computation_calc=comp_evaluation_calc,
                        seed=1234)
                    # Mean absolute error relative to the mean magnitude of the true outputs.
                    current_error = (np.abs(current_outputs - true_out).mean()
                                     / np.abs(true_out).mean())
                    current_class_error = percent_argmax_incorrect(
                        current_outputs, test_targets)
                    if np.isnan(current_error):
                        print('ERROR IS NAN!!!')
                    dbplot((current_flop_counts / 1e6, current_error),
                           '%s error-curve' % (comp_weight, ),
                           axis='error-curve',
                           plot_type='trajectory+',
                           xlabel='MFlops',
                           ylabel='error')
                    dbplot((current_flop_counts / 1e6, current_class_error),
                           '%s class-curve' % (comp_weight, ),
                           axis='class-curve',
                           plot_type='trajectory+',
                           xlabel='MFlops',
                           ylabel='class-error')
            f_train(training_minibatch)
        # Re-read the scales after training: the last in-loop test point may precede
        # the final training steps, or may never occur at all (leaving `ks` unbound).
        ks = f_get_scales()
        optimized_results['lambda=%.3g' % (comp_weight, )] = get_mnist_results_with_parameters(
            weights=weights,
            biases=biases,
            scales=ks,
            hidden_activations=hidden_activations,
            output_activation=output_activation,
            smoothing_steps=smoothing_steps)
    return optimized_results