def test_plotting_server():
    """
    Switch dbplot to the web backend, then push five frames of random data: a (10, 10, 3) noise image named 'noise'
    and a (20, 2) array named 'lines', pausing briefly after each frame.
    """
    setup_web_plotting()

    n_frames = 5
    for _ in xrange(n_frames):
        noise_image = np.random.randn(10, 10, 3)
        dbplot(noise_image, 'noise')
        line_data = np.random.randn(20, 2)
        dbplot(line_data, 'lines')
        plt.pause(.01)
Beispiel #2
0
def moving_point_plot(n_steps=20):
    """
    Feed a MovingPointPlot named "history" with the point (step, step**2), one point every half second.

    At step 4 the first coordinate is replaced with NaN, to show how the plot copes with a missing value.

    :param n_steps: Number of points to plot.
    """
    for step in xrange(n_steps):
        first_coord = float('nan') if step == 4 else step
        point = np.array([first_coord, step**2])
        dbplot(point, "history", plot_type=lambda: MovingPointPlot())
        time.sleep(0.5)
Beispiel #3
0
def test_particular_plot(n_steps=3):
    """
    Plot single random samples with an explicitly chosen plot type: a HistogramPlot with 20 fixed bin edges on [-5, 5].

    :param n_steps: Number of samples to plot.
    """
    reset_dbplot()

    for _ in xrange(n_steps):
        sample = np.random.randn(1)
        histogram_factory = partial(HistogramPlot, edges=np.linspace(-5, 5, 20))
        dbplot(sample, plot_type=histogram_factory)
Beispiel #4
0
 def random_walk():
     """Plot ten steps of a 1-D Gaussian random walk under the name 'walk'."""
     position = 0
     for _ in xrange(10):
         position += np.random.randn()
         # Alternative: plot_type=lambda: MovingPointPlot(axes_update_mode='expand')
         dbplot(position, 'walk')
def demo_temporal_mnist(n_samples = None, smoothing_steps = 200):
    """
    Step through MNIST and temporal-MNIST side by side, showing one sample of each per frame with its label as title.

    Uses the last two entries of each dataset's `.xyxy` (presumably the test inputs and labels).

    :param n_samples: Passed as both n_training_samples and n_test_samples to the dataset getters.
    :param smoothing_steps: Passed through to get_temporal_mnist_dataset.
    """
    mnist = get_mnist_dataset(n_training_samples=n_samples, n_test_samples=n_samples)
    _, _, original_data, original_labels = mnist.xyxy
    temporal = get_temporal_mnist_dataset(n_training_samples=n_samples, n_test_samples=n_samples,
                                          smoothing_steps=smoothing_steps)
    _, _, temporal_data, temporal_labels = temporal.xyxy

    frames = zip(original_data, original_labels, temporal_data, temporal_labels)
    for sample, label, smooth_sample, smooth_label in frames:
        with hold_dbplots():
            dbplot(sample, 'sample', title = str(label))
            dbplot(smooth_sample, 'smooth', title = str(smooth_label))
Beispiel #6
0
def test_particular_plot(n_steps=3):
    """
    Plot random scalars with a HistogramPlot (20 bin edges spanning [-5, 5]) built by a factory, then clear dbplot.

    :param n_steps: Number of samples to plot.
    """
    for _ in xrange(n_steps):
        value = np.random.randn(1)

        def histogram_factory():
            return HistogramPlot(edges=np.linspace(-5, 5, 20))

        dbplot(value, plot_type=histogram_factory)
    clear_dbplot()
def test_periodic_plotting():
    """
    Plot a sine and a cosine over 100 steps inside hold_dbplots(draw_every='1s').

    The draw_every argument presumably throttles redraws to about once per second instead of once per iteration.
    """
    for t in range(100):
        # Float division: under Python 2, t/10 is floor division, which would make both curves step-shaped.
        phase = t / 10.
        with hold_dbplots(draw_every='1s'):
            dbplot(np.sin(phase), 'sinusoid')
            dbplot(np.cos(phase), 'cosinusoid')
        time.sleep(0.02)
Beispiel #8
0
 def vis_callback(xx):
     """
     Plot the weights and biases of every layer of `predictor`'s network in a single dbplot call.

     The first layer's transposed weight matrix is reshaped into 28x28 tiles (presumably one per hidden unit, assuming
     a 784-dimensional input). The other layers are plotted as raw matrices/vectors. Keys are 'Layer[k].w' / 'Layer[k].b'.

     :param xx: Unused here; presumably required by the caller's callback signature.
     """
     p = predictor.symbolic_predictor._function
     first_transform = p.layers[0].linear_transform
     params = {
         'Layer[0].w': first_transform._w.get_value().T.reshape(-1, 28, 28),
         'Layer[0].b': first_transform._b.get_value(),
         }
     # Build the dict directly instead of concatenating `.items()` results: on Python 3, items() returns a view,
     # and `view + list` raises TypeError.
     for i, layer in enumerate(p.layers[1:]):
         params['Layer[%s].w' % (i+1)] = layer.linear_transform._w.get_value()
         params['Layer[%s].b' % (i+1)] = layer.linear_transform._b.get_value()
     dbplot(params)
Beispiel #9
0
def test_list_of_images():
    """Plot, twice, a list of three random images of different shapes in a single dbplot call (no explicit name)."""
    reset_dbplot()
    image_shapes = [(12, 30), (10, 10), (15, 10)]
    for _ in xrange(2):
        dbplot([np.random.randn(*shape) for shape in image_shapes])
Beispiel #10
0
def demo_plot_temporal_mnist(n_rows=8, n_cols=16, smoothing_steps=1000):
    """
    Display a grid of temporal-MNIST samples: n_rows evenly spaced starting offsets into the sequence, each followed
    by n_cols consecutive samples. The call to dbplot uses hang=True.

    :param n_rows: Number of evenly spaced starting offsets.
    :param n_cols: Number of consecutive samples taken from each offset.
    :param smoothing_steps: Passed through to get_temporal_mnist_dataset.
    """
    _, _, temporal_data, temporal_labels = get_temporal_mnist_dataset(
        smoothing_steps=smoothing_steps).xyxy
    # Integer division: on Python 3 a plain "/" gives a float stride, so np.arange would produce float offsets that
    # cannot be used as slice indices.
    stride = len(temporal_data) // n_rows
    starts = np.arange(0, stride * n_rows, stride)
    # Stack to (n_rows, n_cols, ...) and swap the first two axes to get (n_cols, n_rows, ...).
    data = np.array([temporal_data[s:s + n_cols]
                     for s in starts]).swapaxes(0, 1)
    dbplot(data, 'Temporal MNIST', plot_type='pic', hang=True)
def run_plotting_server(address, port):
    """
    Run the plotting server: accept client connections and render the dbplot requests they send, until
    GracefulKiller reports a kill signal.

    The port actually bound may differ from `port`, because get_socket looks for the first available socket starting
    from it. That port is written via write_port_to_file and printed, so the client that started this server can
    find it. save_current_figure is registered with atexit.

    :param address: Address to listen on.
    :param port: Port at which to start looking for an available socket.
    :return: None (returns once a kill signal has been received).
    """

    # Get the first available socket starting from port and communicate it to the client who started this server
    sock, port = get_socket(address=address, port=port)
    write_port_to_file(port)
    max_number_clients = 100
    max_plot_batch_size = 20000
    sock.listen(max_number_clients)
    print(port)
    print("Plotting Server is listening")

    # We want to save and rescue the current plot in case the plotting server receives a signal.SIGINT (2)
    killer = GracefulKiller()

    # The plotting server receives input in a queue and returns the plot_ids as a way of communicating that it has rendered the plot
    main_input_queue = Queue.Queue()
    return_queue = Queue.Queue()
    # Start accepting clients' communication requests
    t0 = threading.Thread(target=handle_socket_accepts,
                          args=(sock, main_input_queue, return_queue,
                                max_number_clients))
    # Use the `daemon` attribute: Thread.setDaemon() is deprecated since Python 3.10 (the attribute also works on 2.6+).
    t0.daemon = True
    t0.start()

    # If killed, save the current figure
    atexit.register(save_current_figure)

    # Now, we can accept plots in the main thread!
    while True:
        if killer.kill_now:
            # The server has received a signal.SIGINT (2), so we stop receiving plots and terminate
            break
        # Retrieve data points that might have come in in the mean-time:
        client_messages = _queue_get_all_no_wait(main_input_queue,
                                                 max_plot_batch_size)
        # client_messages is a list of ClientMessage objects
        if len(client_messages) > 0:
            return_values = []
            with hold_dbplots():
                for client_msg in client_messages:  # For each ClientMessage object
                    # Take apart the received message, plot, and return the plot_id to the client who sent it.
                    # SECURITY: pickle.loads can execute arbitrary code embedded in the payload. Only run this server
                    # where every client able to connect to the socket is trusted.
                    plot_message = pickle.loads(
                        client_msg.dbplot_message
                    )  # A DBPlotMessage object (see plotting_client.py)
                    plot_message.dbplot_args['draw_now'] = False
                    dbplot(**plot_message.dbplot_args)
                    return_values.append(
                        (client_msg.client_address, plot_message.plot_id))
            # Acknowledge only after the whole batch has been drawn.
            for client, plot_id in return_values:
                return_queue.put([client, plot_id])
        else:
            time.sleep(0.1)
def test_two_plots_in_the_same_axis_version_2():
    """
    Overlay a histogram of 200 Gaussian samples with the normal density fitted to them (sample mean and variance).

    Option 2: both dbplot calls pass the same explicit axis name, 'hist', so they share one axis.
    """
    reset_dbplot()
    for _ in xrange(5):
        samples = np.random.randn(200)
        grid = np.linspace(-5, 5, 100)
        mu = np.mean(samples)
        var = np.var(samples)
        density = 1./np.sqrt(2*np.pi*var) * np.exp(-(grid-mu)**2/(2*var))
        with hold_dbplots():
            dbplot(samples, 'histogram', plot_type='histogram', axis='hist')
            dbplot((grid, density), 'density', axis='hist', plot_type='line')
def test_two_plots_in_the_same_axis_version_1():
    """
    Overlay a histogram of 200 Gaussian samples with the normal density fitted to them (sample mean and variance).

    Option 1: the density plot's axis argument names the first plot ('histogram'), so it draws into that plot's axis.
    """
    reset_dbplot()
    n_frames = 5
    for _ in xrange(n_frames):
        draws = np.random.randn(200)
        support = np.linspace(-5, 5, 100)
        sample_var = np.var(draws)
        normaliser = 1./np.sqrt(2*np.pi*sample_var)
        fitted_pdf = normaliser * np.exp(-(support-np.mean(draws))**2/(2*sample_var))
        with hold_dbplots():
            dbplot(draws, 'histogram', plot_type='histogram')
            dbplot((support, fitted_pdf), 'density', axis='histogram', plot_type='line')
def test_inline_custom_plots():
    """
    Over 10 frames, mix a dbplot line plot ('x') with an axis named 'custom' that is drawn into with raw plt.plot calls.
    """
    offsets = np.linspace(0, 10, 200)
    for t in range(10):
        with hold_dbplots():
            x = np.sin(t / 10. + offsets)
            dbplot(x, 'x', plot_type='line')
            # Select the dbplot axis 'custom' (clear=True, presumably wiping the previous frame) so that the plain
            # matplotlib calls below draw into it.
            use_dbplot_axis('custom', clear=True)
            for curve, curve_label in ((x, 'x'), (x**2, '$x**2$')):
                plt.plot(curve, label=curve_label, linewidth=2)
Beispiel #15
0
def test_moving_point_multiple_points():
    """
    Plot two oscillating values per step into two MovingPointPlots: 'unlim buffer' (default buffer) and
    'lim buffer' (buffer_len=20).
    """
    def current_points(step):
        # A fresh array on every call, so the two plots never receive the same object.
        return np.sin([step / 10., step / 15.])

    for step in xrange(5):
        dbplot(current_points(step), 'unlim buffer', plot_type=MovingPointPlot)
        dbplot(current_points(step), 'lim buffer', plot_type=lambda: MovingPointPlot(buffer_len=20))
 def my_exp():
     """
     Draw a spiral that grows over four frames, calling save_figure_in_record() after each frame, then close the figure.
     """
     for t in range(4):
         angle = np.linspace(0, 3 * (t + 1), 400)
         spiral = (angle * np.cos(angle), angle * np.sin(angle))
         dbplot(spiral, 'plot', title='t={}'.format(t), plot_type='line')
         save_figure_in_record()
     current_figure = plt.gcf()
     plt.close(current_figure)
Beispiel #17
0
def test_plotting_server():
    """
    Unless the configured plotting backend is already 'matplotlib-web', set up web plotting. Then push five frames
    of a (10, 10, 3) noise image and a (20, 2) array.
    """
    configured_backend = get_artemis_config_value(section='plotting', option='backend')
    if configured_backend != 'matplotlib-web':
        setup_web_plotting()

    for _ in xrange(5):
        dbplot(np.random.randn(10, 10, 3), 'noise')
        dbplot(np.random.randn(20, 2), 'lines')
        plt.pause(0.1)
Beispiel #18
0
def test_close_and_open():
    """
    Plot 20 frames to 'a', close the current figure, then plot 20 frames to 'b'. This presumably checks that dbplot
    keeps working after its figure has been closed.
    """
    def plot_repeatedly(name):
        for _ in xrange(20):
            dbplot(np.random.randn(5), name)

    plot_repeatedly('a')
    plt.close(plt.gcf())
    plot_repeatedly('b')
Beispiel #19
0
def test_moving_point_multiple_points():
    """
    After resetting dbplot, plot two oscillating values per step into an unbounded MovingPointPlot
    ('unlim buffer') and one limited to 20 points ('lim buffer').
    """
    reset_dbplot()
    for step in xrange(5):
        phases = [step / 10., step / 15.]
        dbplot(np.sin(phases), 'unlim buffer', plot_type=partial(MovingPointPlot))
        dbplot(np.sin(phases), 'lim buffer', plot_type=partial(MovingPointPlot, buffer_len=20))
Beispiel #20
0
def test_close_and_open():
    """
    Plot 20 random frames to 'a', close the current figure, then plot 20 frames to 'b'. This presumably checks that
    plotting recovers once the figure is gone.
    """
    for plot_name in ('a', 'b'):
        for _ in xrange(20):
            dbplot(np.random.randn(5), plot_name)
        if plot_name == 'a':
            plt.close(plt.gcf())
Beispiel #21
0
def demo_rbm_tutorial(
        eta = 0.01,
        n_hidden = 500,
        n_samples = None,
        minibatch_size = 10,
        plot_interval = 10,
        w_init_mag = 0.01,
        n_epochs = 1,
        persistent = False,
        seed = None
        ):
    """
    This tutorial trains a standard binary-binary RBM on MNIST, and allows you to view the weights and negative sampling
    chain.

    Note:
    For simplicity, it uses hidden/visible samples to compute the gradient.  It's actually better to use the hidden
    probabilities.

    :param eta: Learning rate.
    :param n_hidden: Number of hidden units.
    :param n_samples: Number of MNIST training images to use (None: all of them).
    :param minibatch_size: Number of samples per parameter update.
    :param plot_interval: Plot the weights and the negative-phase ("dreams") samples every this many minibatches.
    :param w_init_mag: Scale of the Gaussian weight initialisation.
    :param n_epochs: Number of passes through the data.
    :param persistent: If True, the sleep (negative) chain carries over from one minibatch to the next; otherwise
        it restarts from the current wake-phase hidden sample.
    :param seed: Seed for the random number generator. All randomness (weight init, persistent chain init, unit
        sampling) is drawn from it, so a fixed seed gives a reproducible run.
    """
    if is_test_mode():
        n_samples=50
        n_epochs=1
        plot_interval=50
        n_hidden = 10

    data = get_mnist_dataset(flat = True).training_set.input[:n_samples]
    n_visible = data.shape[1]
    rng = np.random.RandomState(seed)

    def activation(x):
        # Stochastic binary units: 1 with probability sigmoid(x), else 0.
        return (1./(1+np.exp(-x)) > rng.rand(*x.shape)).astype(float)

    # Draw the initial weights from `rng` rather than the global numpy RNG, so that `seed` controls them too.
    w = w_init_mag*rng.randn(n_visible, n_hidden)
    b_hid = np.zeros(n_hidden)
    b_vis = np.zeros(n_visible)

    if persistent:
        # NOTE: the persistent chain starts from continuous values in [0, 1), not from binary samples.
        hid_sleep_state = rng.rand(minibatch_size, n_hidden)

    for i, vis_wake_state in enumerate(minibatch_iterate(data, n_epochs = n_epochs, minibatch_size=minibatch_size)):
        hid_wake_state = activation(vis_wake_state.dot(w)+b_hid)
        if not persistent:
            hid_sleep_state = hid_wake_state
        vis_sleep_state = activation(hid_sleep_state.dot(w.T)+b_vis)
        hid_sleep_state = activation(vis_sleep_state.dot(w)+b_hid)

        # Update Parameters: difference of wake-phase and sleep-phase correlations.
        w_grad = (vis_wake_state.T.dot(hid_wake_state) - vis_sleep_state.T.dot(hid_sleep_state))/float(minibatch_size)
        w += w_grad * eta
        b_vis_grad = np.mean(vis_wake_state, axis = 0) - np.mean(vis_sleep_state, axis = 0)
        b_vis += b_vis_grad * eta
        b_hid_grad = np.mean(hid_wake_state, axis = 0) - np.mean(hid_sleep_state, axis = 0)
        b_hid += b_hid_grad * eta

        if i % plot_interval == 0:
            dbplot(w.T[:100].reshape(-1, 28, 28), 'weights')
            dbplot(vis_sleep_state.reshape(-1, 28, 28), 'dreams')
            print('Sample %s' % i)
Beispiel #22
0
def demo_debug_dbplot():
    """
    Alternate dbplot calls with pdb breakpoints, presumably to check how plots behave while stopped in the debugger.
    """
    import pdb
    plot_steps = (
        (lambda: np.random.randn(50, 2), 'a', 'aaa'),
        (lambda: np.random.randn(10, 10), 'b', 'bbb'),
    )
    for _ in xrange(1000):
        for make_data, plot_name, message in plot_steps:
            dbplot(make_data(), plot_name)
            print(message)
            pdb.set_trace()
Beispiel #23
0
def test_smart_image_io(plot = False):
    """
    Round-trip an image through smart_save / smart_load.

    Loads a JPEG from the web (cached), saves a horizontally flipped copy as PNG, reloads it, and checks that the
    reloaded pixels equal the array that was saved.

    :param plot: If True, display the original and the reloaded image (the second with hang=True).
    """
    image = smart_load('https://raw.githubusercontent.com/petered/data/master/images/artemis.jpeg', use_cache=True)
    flipped = image[:, ::-1, :]
    smart_save(flipped, 'output/simetra.png')
    rev_image = smart_load('output/simetra.png')
    # Compare against the saved array (PNG is lossless); comparing rev_image with itself would always pass.
    assert np.array_equal(rev_image, flipped)
    if plot:
        from artemis.plotting.db_plotting import dbplot
        dbplot(image, 'Artemis')
        dbplot(rev_image, 'Simetra', hang=True)
Beispiel #24
0
def demo_debug_dbplot():
    """
    For up to 1000 rounds: plot to 'a', print 'aaa' and break into pdb, then plot to 'b', print 'bbb' and break again.
    """
    import pdb
    round_index = 0
    while round_index < 1000:
        points = np.random.randn(50, 2)
        dbplot(points, 'a')
        print('aaa')
        pdb.set_trace()
        matrix = np.random.randn(10, 10)
        dbplot(matrix, 'b')
        print('bbb')
        pdb.set_trace()
        round_index += 1
def test_plotting_server():
    """
    Read the plotting backend from the artemis config and, unless it is 'matplotlib-web', set up web plotting.
    Then send five frames of a noise image and a (20, 2) array.
    """
    backend = get_artemis_config().get('plotting', 'backend')
    if backend != 'matplotlib-web':
        setup_web_plotting()

    for _ in xrange(5):
        noise = np.random.randn(10, 10, 3)
        dbplot(noise, 'noise')
        lines = np.random.randn(20, 2)
        dbplot(lines, 'lines')
        plt.pause(0.1)
Beispiel #26
0
def test_plotting_server():
    """
    Make sure the web plotting backend is in use, then push five frames of random image/line data through dbplot.
    """
    already_web = get_artemis_config_value(section='plotting', option='backend') == 'matplotlib-web'
    if not already_web:
        setup_web_plotting()

    frame = 0
    for frame in xrange(5):
        dbplot(np.random.randn(10, 10, 3), 'noise')
        dbplot(np.random.randn(20, 2), 'lines')
        plt.pause(0.1)
Beispiel #27
0
def test_same_object():
    """
    Regression test: plotting the very same array object under two names used to show "Already seen object" on one
    of the plots. With the bug fixed, both 'a' and 'b' show the matrix.
    """
    reset_dbplot()
    shared_matrix = np.random.randn(20, 20)
    for _ in xrange(5):
        for plot_name in ('a', 'b'):
            dbplot(shared_matrix, plot_name)
Beispiel #28
0
def test_dbplot(n_steps=3):
    """
    Plot an image that is squared and divided by its mean at every step, plus three random (10, 2) arrays per step
    drawn as LinePlots under the name 'barr'. Clear dbplot at the end.

    :param n_steps: Number of update steps.
    """
    image = np.random.rand(10, 10)
    for step in xrange(n_steps):
        squared = image**2
        image = squared / np.mean(squared)
        dbplot(image, 'arr')
        for repeat in xrange(3):
            lines = np.random.randn(10, 2)
            dbplot(lines, 'barr', plot_type=lambda: LinePlot())
    clear_dbplot()
Beispiel #29
0
def test_same_object():
    """
    Regression check for the "Already seen object" bug. The same array is plotted alternately under 'a' and 'b',
    five times each, and both plots should display the matrix.
    """
    reset_dbplot()
    matrix = np.random.randn(20, 20)
    for name in ['a', 'b'] * 5:
        dbplot(matrix, name)
Beispiel #30
0
def demo_temporal_mnist(n_samples=None, smoothing_steps=200):
    """
    Show MNIST samples next to their temporal-MNIST counterparts, one pair per frame, each titled with its label.

    :param n_samples: Used as both n_training_samples and n_test_samples.
    :param smoothing_steps: Passed through to get_temporal_mnist_dataset.
    """
    size_kwargs = dict(n_training_samples=n_samples, n_test_samples=n_samples)
    _, _, plain_x, plain_y = get_mnist_dataset(**size_kwargs).xyxy
    _, _, smooth_x, smooth_y = get_temporal_mnist_dataset(
        smoothing_steps=smoothing_steps, **size_kwargs).xyxy
    for sample, label, smooth_sample, smooth_label in zip(plain_x, plain_y, smooth_x, smooth_y):
        with hold_dbplots():
            dbplot(sample, 'sample', title=str(label))
            dbplot(smooth_sample, 'smooth', title=str(smooth_label))
Beispiel #31
0
def demo_gan_mnist(n_epochs=20,
                   minibatch_size=20,
                   n_discriminator_steps=1,
                   noise_dim=10,
                   plot_period=100,
                   rng=1234):
    """
    Train a Generative Adversarial Network on MNIST, periodically plotting real and generated ("Counterfeit") digits.

    :param n_epochs: Number of epochs to train
    :param minibatch_size: Size of minibatch to feed in each training iteration
    :param n_discriminator_steps: Number of discriminator training steps per generator training step
    :param noise_dim: Dimensionality of the latent space from which random samples are drawn
    :param plot_period: Plot every N training iterations
    :param rng: Random number generator or seed (passed to both networks and to the GAN)
    """
    discriminator = MultiLayerPerceptron.from_init(w_init=0.01,
                                                   layer_sizes=[784, 100, 1],
                                                   hidden_activation='relu',
                                                   output_activation='sig',
                                                   rng=rng)
    generator = MultiLayerPerceptron.from_init(w_init=0.1,
                                               layer_sizes=[noise_dim, 200, 784],
                                               hidden_activation='relu',
                                               output_activation='sig',
                                               rng=rng)
    net = GenerativeAdversarialNetwork(discriminator=discriminator,
                                       generator=generator,
                                       noise_dim=noise_dim,
                                       optimizer=AdaMax(0.001),
                                       rng=rng)

    data = get_mnist_dataset(flat=True).training_set.input

    train_discriminator = net.train_discriminator.compile()
    train_generator = net.train_generator.compile()
    generate = net.generate.compile()

    batches = minibatch_iterate(data, n_epochs=n_epochs, minibatch_size=minibatch_size)
    for step, minibatch in enumerate(batches):
        train_discriminator(minibatch)
        print('Trained Discriminator')
        # The generator is trained once every n_discriminator_steps iterations.
        if step % n_discriminator_steps == n_discriminator_steps - 1:
            train_generator(n_samples=minibatch_size)
            print('Trained Generator')
        if step % plot_period == 0:
            fakes = generate(n_samples=minibatch_size)
            dbplot(minibatch.reshape(-1, 28, 28), "Real")
            dbplot(fakes.reshape(-1, 28, 28), "Counterfeit")
            print('Disp')
Beispiel #32
0
def classify(f, im_path):
    """
    Load an image, run `f` on its VGG-formatted version, and plot the image titled with the top label and its score.

    :param f: Callable mapping the output of im2vgginput to class scores (presumably a compiled VGG network).
    :param im_path: Anything smart_load accepts (e.g. a local path or URL).
    """
    image = smart_load(im_path)
    print('Processing image... "%s"' % (im_path, ))
    vgg_input = im2vgginput(image)
    scores = f(vgg_input)
    best_class = np.argmax(scores[0])
    label = get_vgg_label_at(best_class)
    print('Done.')
    # Move axis 0 of the first input to the end and reverse the last axis (presumably CHW/BGR -> HWC/RGB) for display.
    photo = np.rollaxis(vgg_input[0], 0, 3)[..., ::-1]
    title = "{label}: {pct}%".format(label=label, pct=scores[0, best_class, 0, 0] * 100)
    dbplot(photo, 'Photo', title=title)
Beispiel #33
0
def test_dbplot(n_steps = 3):
    """
    Plot an image that is renormalised each step (squared, then divided by its mean), and three random (10, 2)
    arrays per step drawn as LinePlots under the name 'barr'.

    :param n_steps: Number of update steps.
    """
    def squared_and_normalised(a):
        squared = a**2
        return squared/np.mean(squared)

    reset_dbplot()

    arr = np.random.rand(10, 10)
    for step in xrange(n_steps):
        arr = squared_and_normalised(arr)
        dbplot(arr, 'arr')
        for repeat in xrange(3):
            dbplot(np.random.randn(10, 2), 'barr', plot_type=partial(LinePlot))
Beispiel #34
0
    def test_callback(info, score):
        """
        Per-evaluation callback.

        It optionally plots the first layer's weights, prints the mean per-layer operation counts per epoch, aborts
        a run that is not learning, and records the op counts in op_count_info.

        Uses names from the enclosing scope: plot, net, swap_mlp, dataset, hidden_sizes, test_period, op_count_info.

        :param info: Training-progress object; its .epoch and .sample attributes are read here.
        :param score: Score object; score.get_score('train', 'noise_free') is read here.
        :raises Exception: If, after more than max(0.5, 2 * test_period) epochs and with swap_mlp off, the noise-free
            training score is still below 20.
        """
        if plot:
            first_layer_weights = net.layers[0].w.get_value().T.reshape(-1, 28, 28)
            dbplot(first_layer_weights, 'w0', cornertext='Epoch {}'.format(info.epoch))

        if swap_mlp:
            # Estimate each layer's count as (samples seen) x (fan-in) x (fan-out); the same estimate is used for the
            # forward, backward and update passes.
            all_layer_sizes = [dataset.input_size] + hidden_sizes + [dataset.target_size]
            layer_shapes = list(zip(all_layer_sizes[:-1], all_layer_sizes[1:]))
            fwd_ops = [info.sample * d_in * d_out for d_in, d_out in layer_shapes]
            back_ops = [info.sample * d_in * d_out for d_in, d_out in layer_shapes]
            update_ops = [info.sample * d_in * d_out for d_in, d_out in layer_shapes]
        else:
            fwd_ops = [lay.fwd_op_count.get_value() for lay in net.layers]
            back_ops = [lay.back_op_count.get_value() for lay in net.layers]
            update_ops = [lay.update_op_count.get_value() for lay in net.layers]

        if info.epoch != 0:
            with IndentPrint('Mean Ops by epoch {}'.format(info.epoch)):
                for caption, op_counts in (('Fwd', fwd_ops), ('Back', back_ops), ('Update', update_ops)):
                    per_epoch = [si_format(ops / info.epoch, format_str='{value} {prefix}Ops') for ops in op_counts]
                    print('{}: {}'.format(caption, per_epoch))

        not_learning = (info.epoch > max(0.5, 2 * test_period)
                        and not swap_mlp
                        and score.get_score('train', 'noise_free') < 20)
        if not_learning:
            raise Exception("This horse ain't goin' nowhere.")

        op_count_info.append((info, (fwd_ops, back_ops, update_ops)))
Beispiel #35
0
def demo_pytorch_vae_mnist(hidden_sizes=None,
                           latent_dim=5,
                           distribution_type='bernoulli',
                           minibatch_size=20,
                           checkpoints=100,
                           n_epochs=20):
    """
    Train a VAE with MLP encoder and decoder on MNIST (PyTorch). At each checkpoint, print the training rate and plot
    the means of 64 images decoded from prior samples.

    :param hidden_sizes: Hidden layer sizes for both encoder and decoder. None (the default) means [200, 200]; None is
        used instead of a list literal so that calls never share one mutable default list.
    :param latent_dim: Dimensionality of the latent space.
    :param distribution_type: Decoder output distribution, passed as dist_type to make_mlp_decoder.
    :param minibatch_size: Training minibatch size.
    :param checkpoints: Checkpoint schedule passed to Checkpoints; at each checkpoint the rate is printed and
        samples are plotted.
    :param n_epochs: Number of passes over the MNIST training set.
    """
    if hidden_sizes is None:
        hidden_sizes = [200, 200]

    cp = Checkpoints(checkpoints)

    model = VAEModel(
        encoder=make_mlp_encoder(visible_dim=784,
                                 hidden_sizes=hidden_sizes,
                                 latent_dim=latent_dim),
        decoder=make_mlp_decoder(latent_dim=latent_dim,
                                 hidden_sizes=hidden_sizes,
                                 visible_dim=784,
                                 dist_type=distribution_type),
        latent_dim=latent_dim,
    )
    # Adam, RMSprop, Adamax or SGD(lr=0.001) can be swapped in here as alternatives.
    optimizer = Adagrad(params=model.parameters())

    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()])),
                                               batch_size=minibatch_size,
                                               shuffle=True)

    for epoch in range(n_epochs):
        for batch_idx, (x, y) in enumerate(train_loader):

            # Fractional epoch, for progress reporting.
            epoch_pt = epoch + batch_idx / len(train_loader)

            optimizer.zero_grad()
            # Minimise the negative ELBO, summed over the minibatch.
            loss = -model.elbo(x.flatten(1)).sum()
            loss.backward()
            optimizer.step()

            rate = measure_global_rate('training')

            if cp():

                print(f'Mean Rate at Epoch {epoch_pt:.2g}: {rate:.3g}iter/s')
                z_samples = model.prior().sample((64, ))
                x_dist = model.decode(z_samples)
                dbplot(x_dist.mean.reshape(-1, 28, 28),
                       'Sample Means',
                       title=f'Sample Means at epoch {epoch_pt:.2g}')
Beispiel #36
0
def test_dbplot(n_steps=3):
    """
    Reset dbplot, then for each step plot a renormalised image ('arr') and three random line plots ('barr').

    :param n_steps: Number of update steps.
    """
    reset_dbplot()

    arr = np.random.rand(10, 10)
    for _ in xrange(n_steps):
        squared = arr**2
        arr = squared / np.mean(squared)
        dbplot(arr, 'arr')
        line_batches = (np.random.randn(10, 2) for _ in xrange(3))
        for barr in line_batches:
            dbplot(barr, 'barr', plot_type=partial(LinePlot))
Beispiel #37
0
def test_smart_image_io(plot=False):
    """
    Round-trip an image through smart_save / smart_load.

    Loads a JPEG from the web (cached), saves a horizontally flipped copy as PNG, reloads it, and checks that the
    reloaded pixels equal the array that was saved.

    :param plot: If True, display the original and the reloaded image (the second with hang=True).
    """
    image = smart_load(
        'https://raw.githubusercontent.com/petered/data/master/images/artemis.jpeg',
        use_cache=True)
    flipped = image[:, ::-1, :]
    smart_save(flipped, 'output/simetra.png')
    rev_image = smart_load('output/simetra.png')
    # Compare against the saved array (PNG is lossless); comparing rev_image with itself would always pass.
    assert np.array_equal(rev_image, flipped)
    if plot:
        from artemis.plotting.db_plotting import dbplot
        dbplot(image, 'Artemis')
        dbplot(rev_image, 'Simetra', hang=True)
Beispiel #38
0
def test_dbplot_logscale(n_steps = 3):
    """
    Plot a renormalised image ('arr') and, per step, three random (10, 2) arrays drawn as LinePlots with a
    logarithmic y-axis ('barr').

    :param n_steps: Number of update steps.
    """
    reset_dbplot()

    arr = np.random.rand(10, 10)

    for i in xrange(n_steps):
        arr_sq=arr**2
        arr = arr_sq/np.mean(arr_sq)
        dbplot(arr, 'arr')
        for j in xrange(3):
            barr = np.random.randn(10, 2)
            dbplot(barr, 'barr', plot_type=partial(LinePlot, y_axis_type='log'))
Beispiel #39
0
def test_custom_axes_placement(hang=False):
    """
    Place dbplots into explicit GridSpec cells. The left half of the figure gets three stacked sine line plots with no
    vertical gap; the right half gets a greyscale and a 3-channel random image.

    :param hang: If True, call dbplot_hang() at the end.
    """
    left_grid = gridspec.GridSpec(3, 1, left=0, right=0.5, hspace=0)
    for row, (plot_name, shift) in enumerate([('a', 0), ('b', 1), ('c', 2)]):
        wave = np.sin(np.linspace(0, 10, 100) + shift)
        dbplot(wave, plot_name, plot_type='line', axis=left_grid[row, 0])

    right_grid = gridspec.GridSpec(2, 1, left=0.5, right=1, hspace=0.1)
    dbplot(np.random.randn(20, 20), 'im1', axis=right_grid[0, 0])
    dbplot(np.random.randn(20, 20, 3), 'im2', axis=right_grid[1, 0])

    if hang:
        dbplot_hang()
Beispiel #40
0
def test_multiple_figures():
    """
    Twice, plot four random images split across two figures: figure '1' holds 'a' and 'b', figure '2' holds 'c' and 'd'.
    """
    reset_dbplot()
    layout = [('a', '1'), ('b', '1'), ('c', '2'), ('d', '2')]
    for _ in xrange(2):
        for plot_name, figure_name in layout:
            dbplot(np.random.randn(20, 20), plot_name, fig=figure_name)
Beispiel #41
0
def test_history_plot_updating():
    """
    Regression test for issue 1 (https://github.com/QUVA-Lab/artemis/issues/1).

    When several plots with history were updated in a loop, every update of any plot pushed the most recent data into
    all of them. With the bug, the moving-point plot 'c' moves in steps (the same sample three times in a row); once
    fixed, it should look spiky.
    """
    reset_dbplot()
    for _ in xrange(10):
        for image_name in ('a', 'b'):
            dbplot(np.random.randn(20, 20), image_name)
        dbplot(np.random.randn(), 'c', plot_type=partial(MovingPointPlot))
Beispiel #42
0
def test_moving_point_multiple_points():
    """
    Plot two oscillating values per step into MovingPointPlots with unlimited ('unlim buffer') and 20-point
    ('lim buffer') buffers.
    """
    reset_dbplot()
    plot_specs = [
        ('unlim buffer', {}),
        ('lim buffer', {'buffer_len': 20}),
    ]
    for i in xrange(5):
        for plot_name, plot_kwargs in plot_specs:
            dbplot(np.sin([i/10., i/15.]), plot_name, plot_type=partial(MovingPointPlot, **plot_kwargs))
Beispiel #43
0
def test_cornertext():
    """Plot three random 5x5 images to the same plot 'a', each labelled with a different corner text."""
    for corner_label in ('one', 'two', 'three'):
        dbplot(np.random.randn(5, 5), 'a', cornertext=corner_label)
Beispiel #44
0
def run_plotting_server(address, port, client_address, client_port):
    """
    Run the plotting server: accept client connections and render the dbplot requests they send, until
    GracefulKiller reports a kill signal.

    The port actually bound (get_socket picks the first available one starting from `port`) is written via
    write_port_to_file and sent to the client that started this server at (client_address, client_port).

    In the first non-empty batch of messages, a message whose dbplot name is "exp_dir" is not plotted. Instead its
    data is passed to save_current_figure, which is registered with atexit. If that batch contains no such message,
    save_current_figure is registered without arguments.

    :param address: Address to listen on.
    :param port: Port at which to start looking for an available socket.
    :param client_address: Address of the client that started this server; the chosen port is sent there.
    :param client_port: Port of that client.
    :return: None (returns once a kill signal has been received).
    """

    # Get the first available socket starting from port and communicate it to the client who started this server
    sock, port = get_socket(address=address, port=port)
    write_port_to_file(port)
    max_number_clients = 100
    max_plot_batch_size = 2000
    sock.listen(max_number_clients)
    one_time_send_to(address=client_address, port=client_port, message=str(port))

    # We want to save and rescue the current plot in case the plotting server receives a signal.SIGINT (2)
    killer = GracefulKiller()

    # The plotting server receives input in a queue and returns the plot_ids as a way of communicating that it has rendered the plot
    main_input_queue = Queue.Queue()
    return_queue = Queue.Queue()
    # Start accepting clients' communication requests
    t0 = threading.Thread(target=handle_socket_accepts, args=(sock, main_input_queue, return_queue, max_number_clients))
    # Use the `daemon` attribute: Thread.setDaemon() is deprecated since Python 3.10 (the attribute also works on 2.6+).
    t0.daemon = True
    t0.start()

    # Whether the atexit figure-saving hook has been registered yet (set after the first non-empty batch).
    exp_dir_received = False

    set_dbplot_figure_size(9, 10)
    # Now, we can accept plots in the main thread!
    while True:
        if killer.kill_now:
            sock.close()  # will cause handle_socket_accepts thread to terminate
            # The server has received a signal.SIGINT (2), so we stop receiving plots and terminate
            break
        # Retrieve data points that might have come in in the mean-time:
        client_messages = _queue_get_all_no_wait(main_input_queue, max_plot_batch_size)
        # client_messages is a list of ClientMessage objects
        if len(client_messages) == 0:
            time.sleep(0.1)
            continue

        return_values = []
        with hold_dbplots():
            for client_msg in client_messages:  # For each ClientMessage object
                # SECURITY: pickle.loads can execute arbitrary code embedded in the payload. Only run this server
                # where every client able to connect to the socket is trusted.
                plot_message = pickle.loads(client_msg.dbplot_message)  # A DBPlotMessage object (see plotting_client.py)
                plot_message.dbplot_args['draw_now'] = False

                if not exp_dir_received and plot_message.dbplot_args.get("name") == "exp_dir":
                    # Its data (presumably the experiment directory) is handed to save_current_figure at exit.
                    # NOTE(review): this message's plot_id is never put on return_queue; confirm clients don't wait on it.
                    atexit.register(save_current_figure, plot_message.dbplot_args["data"])
                    exp_dir_received = True
                    continue
                axis = dbplot(**plot_message.dbplot_args)
                axis.ticklabel_format(style='sci', useOffset=False)
                return_values.append((client_msg.client_address, plot_message.plot_id))
            if not exp_dir_received:
                atexit.register(save_current_figure)
                exp_dir_received = True
            plt.rcParams.update({'axes.titlesize': 'small', 'axes.labelsize': 'small'})
            plt.subplots_adjust(hspace=0.4, wspace=0.6)
        # Acknowledge only after the whole batch has been drawn.
        for client, plot_id in return_values:
            return_queue.put([client, plot_id])
Beispiel #45
0
    subpath = \
        'ILSVRC2015/Data/VID/snippets/test' if 'test' in identifier else \
        'ILSVRC2015/Data/VID/snippets/val' if 'val' in identifier else \
        'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0001/' if os.path.exists(os.path.join(archive_folder_path, 'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0001/', identifier + '.mp4')) else \
        'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0002/' if os.path.exists(os.path.join(archive_folder_path, 'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0002/', identifier + '.mp4')) else \
        'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0003/' if os.path.exists(os.path.join(archive_folder_path, 'ILSVRC2015/Data/VID/snippets/train/ILSVRC2015_VID_train_0003/', identifier + '.mp4')) else \
        bad_value(identifier, 'Could not find identifier: {}'.format(identifier, ))

    print('Loading %s' % (identifier, ))
    full_path = get_file_in_archive(
        relative_path='data/ILSVRC2015',
        subpath=os.path.join(subpath, identifier+'.mp4'),
        url='http://vision.cs.unc.edu/ilsvrc2015/ILSVRC2015_VID_snippets_final.tar.gz'
        )
    video = smart_load_video(full_path, size=size, cut_edges=cut_edges, resize_mode=resize_mode, cut_edges_thresh=cut_edges_thresh)
    print('Done.')
    return video


if __name__ == '__main__':
    # Demo: load a few ILSVRC2015 video snippets and play them side by side, looping forever.
    import itertools
    from artemis.plotting.db_plotting import dbplot, hold_dbplots

    identifiers = ['ILSVRC2015_train_00033009', 'ILSVRC2015_train_00033010', 'ILSVRC2015_train_00763000', 'ILSVRC2015_test_00004002']
    videos = [load_ilsvrc_video(identifier, size=(224, 224), cut_edges=True) for identifier in identifiers]

    for frame_counter in itertools.count(0):
        with hold_dbplots():
            for identifier, vid in zip(identifiers, videos):
                # Each video loops independently of the others' lengths.
                frame_ix = frame_counter % len(vid)
                dbplot(vid[frame_ix], identifier, title='%s: %s' % (identifier, frame_ix))
Beispiel #46
0
def test_trajectory_plot():
    """Plot five points of the (x, y) = (cos(i/10), sin(i/11)) path with the 'trajectory' plot type."""
    for step in xrange(5):
        point = (np.cos(step / 10.), np.sin(step / 11.))
        dbplot(point, 'path', plot_type='trajectory')
Beispiel #47
0
def test_individual_periodic_plotting():
    """
    Plot a sine and a cosine over 100 steps, each with its own draw_every interval ('0.5s' and '1s' respectively).
    """
    for t in range(100):
        # Float division: under Python 2, t/10 is floor division, which would make both curves step-shaped.
        phase = t / 10.
        dbplot(np.sin(phase), 'sinusoid', draw_every='0.5s')
        dbplot(np.cos(phase), 'cosinusoid', draw_every='1s')
        time.sleep(0.02)
def demo_plot_temporal_mnist(n_rows=8, n_cols=16, smoothing_steps=1000):
    """
    Display a grid of temporal-MNIST samples: n_rows evenly spaced starting offsets, each followed by n_cols
    consecutive samples (plotted as 'pic' with hang=True).

    :param n_rows: Number of evenly spaced starting offsets.
    :param n_cols: Number of consecutive samples taken from each offset.
    :param smoothing_steps: Passed through to get_temporal_mnist_dataset.
    """
    _, _, temporal_data, temporal_labels = get_temporal_mnist_dataset(smoothing_steps=smoothing_steps).xyxy
    # "//" keeps the stride an integer on Python 3, so the offsets below are valid slice indices.
    stride = len(temporal_data)//n_rows
    starts = np.arange(0, stride*n_rows, stride)
    data = np.array([temporal_data[s:s+n_cols] for s in starts]).swapaxes(0, 1)
    dbplot(data, 'Temporal MNIST', plot_type = 'pic', hang=True)
Beispiel #49
0
def test_particular_plot(n_steps = 3):
    """
    Plot single random draws using a HistogramPlot whose 20 bin edges span [-5, 5].

    :param n_steps: Number of draws to plot.
    """
    reset_dbplot()

    for _ in xrange(n_steps):
        draw = np.random.randn(1)
        bin_edges = np.linspace(-5, 5, 20)
        dbplot(draw, plot_type=partial(HistogramPlot, edges=bin_edges))
Beispiel #50
0
def test_list_of_images():
    """Twice, plot a list of three differently shaped random images with a single dbplot call."""
    reset_dbplot()
    for _ in xrange(2):
        images = []
        for rows, cols in ((12, 30), (10, 10), (15, 10)):
            images.append(np.random.randn(rows, cols))
        dbplot(images)
Beispiel #51
0
 def random_walk():
     """Accumulate ten Gaussian steps and plot the running total after each one, under the name 'walk'."""
     total = 0
     for _ in xrange(10):
         step = np.random.randn()
         total = total + step
         # Alternative: plot_type=lambda: MovingPointPlot(axes_update_mode='expand')
         dbplot(total, 'walk')
Beispiel #52
0
def get_imagenet_images(indices):
    """
    Get imagenet images at the given indices.

    Each image is fetched with get_file, using a relative path 'data/imagenet/<code><extension>' where the extension
    is taken from the image's URL. All files are fetched first, then loaded with smart_load.

    :param indices: Sequence of non-negative integer indices into the fall11 (code, url) list.
    :return: List of loaded images, in the same order as `indices`. Empty if `indices` is empty.
    """
    if len(indices) == 0:
        # np.max raises on an empty sequence, and there is nothing to fetch anyway.
        return []
    highest_index = np.max(indices)
    # Only request enough (code, url) pairs to cover the highest index asked for.
    code_url_pairs = get_imagenet_fall11_urls(highest_index+1)
    files = []
    for index in indices:
        pair = code_url_pairs[index]
        code, url = pair[0], pair[1]
        extension = os.path.splitext(url)[1]
        files.append(get_file('data/imagenet/%s%s' % (code, extension), url))
    return [smart_load(f) for f in files]


def get_imagenet_label_names():
    """
    Return the ImageNet class names from the imagenet1000_clsid_to_human gist, in file order.

    Each relevant line of that file looks like `<id>: '<name>',`, and the whole file is wrapped in braces. Both
    single and double quotes are stripped from names, because a Python-literal dump double-quotes names that contain
    an apostrophe. Lines without a ':' (e.g. a blank trailing line) are skipped instead of raising ValueError.

    :return: List of class-name strings.
    """

    url = 'https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/596b27d23537e5a1b5751d2b0481ef172f58b539/imagenet1000_clsid_to_human.txt'
    with open(get_file('data/imagenet/labels.json', url=url)) as f:
        label_items = f.read()

    labels = []
    for line in label_items.splitlines():
        if ':' not in line:
            continue
        raw_name = line[line.index(':')+1:]
        labels.append(raw_name.lstrip(' \'"').rstrip('}, \'"'))
    return labels

if __name__ == '__main__':
    # Download 4 random images from among the first 1000 and display them (hang=True on the last one).
    # Some URLs are dead (404 errors, etc.), so you may need to run this several times before it succeeds.
    import random
    from artemis.plotting.db_plotting import dbplot
    ixs = [random.randint(0, 999) for _ in xrange(4)]
    print(ixs)
    ims = get_imagenet_images(ixs)
    last_position = len(ims) - 1
    for position, im in enumerate(ims):
        dbplot(im, 'Image %s' % position, hang=(position == last_position))
Beispiel #53
0
def demo_dbplot(n_frames = 1000):
    """
    Demonstrates the various types of plots.

    The appropriate plot can usually be inferred from the first input data.  In cases where there are multiple ways to
    display the input data, you can use the plot_type argument.

    :param n_frames: Number of frames to draw. Every 10th frame, that frame's rate is printed.
    """
    # Approximate frame rates:
    # Macbook Air, MacOSX backend, mode=safe ~2.4 FPS
    # Macbook Air, MacOSX backend, mode=fast:  FAILS
    # Macbook Air, TkAgg backend, mode=safe: ~1.48 FPS
    # Macbook Air, TkAgg backend, mode=fast: ~4.4 FPS
    # Linux box, Qt4Agg backend, mode=safe: ~2.5 FPS
    # Linux box, Qt4Agg backend, mode=fast: (plot does not update)

    # set_dbplot_figure_size(150, 100)
    for i in xrange(n_frames):
        t_start = time.time()
        with hold_dbplots():  # Sets it so that all plots update at once (rather than redrawing on each call, which is slower)
            dbplot(np.random.randn(20, 20), 'Greyscale Image')
            dbplot(np.random.randn(20, 20, 3), 'Colour Image')
            dbplot(np.random.randn(15, 20, 20), "Many Images")
            dbplot(np.random.randn(3, 6, 20, 20, 3), "Colour Image Grid")
            dbplot([np.random.randn(15, 20, 3), np.random.randn(10, 10, 3), np.random.randn(10, 30, 3)], "Differently Sized images")
            dbplot(np.random.randn(20, 2), 'Two Lines')
            dbplot((np.linspace(-5, 5, 100)+np.sin(np.linspace(-5, 5, 100)*2), np.linspace(-5, 5, 100)), '(X,Y) Lines')
            dbplot([np.sin(i/20.), np.sin(i/15.)], 'Moving Point History')
            dbplot([np.sin(i/20.), np.sin(i/15.)], 'Bounded memory history', plot_type=partial(MovingPointPlot, buffer_len=10))
            dbplot((np.cos(i/10.), np.sin(i/11.)), '(X,Y) Moving Point History')
            # The (x, y) tuple becomes a (2, 3) array, and adding [0, 1, 2] offsets each of the three points.
            dbplot((np.cos(i/np.array([10., 15., 20.])), np.sin(i/np.array([11., 16., 21.])))+np.array([0, 1, 2]), 'multi (X,Y) Moving Point History', plot_type='trajectory')
            dbplot(np.random.randn(20), 'Vector History')
            dbplot(np.random.randn(50), 'Histogram', plot_type = 'histogram')
            dbplot(np.random.randn(50), 'Cumulative Histogram', plot_type = 'cumhist')
            dbplot(('Veni', 'Vidi', 'Vici')[i%3], 'text-history')
            dbplot(('Veni', 'Vidi', 'Vici')[i%3], 'text-notice', plot_type='notice')
        if i % 10 == 0:
            # '.3g' = 3 significant digits; the previous '{:3g}' only set a minimum field width of 3.
            print('Frame Rate: {:.3g}FPS'.format(1./(time.time() - t_start)))