Esempio n. 1
0
def initialize_visdom(visStartUpSec=5):
    """Return a Visdom client for http://localhost:8097, waiting for the server.

    ``check_connection`` is polled every 0.1 s until it succeeds or about
    ``visStartUpSec`` seconds of sleeping have been spent.

    Args:
        visStartUpSec: approximate number of seconds to wait for the server.

    Returns:
        The ``Visdom`` client.

    Raises:
        AssertionError: if the server is still unreachable after waiting.
    """
    client = Visdom(server='http://localhost', port=8097)

    remaining = visStartUpSec
    while not client.check_connection():
        # Budget exhausted (written as "not >" so NaN also stops the loop).
        if not (remaining > 0.0):
            break
        time.sleep(0.1)
        remaining -= 0.1
    assert client.check_connection(), 'No connection could be formed quickly'

    print("VisdomLinePlotter initialized.")
    return client
Esempio n. 2
0
 def testCallbackUpdateGraph(self):
     """Spawn a visdom server, wait for it, then run the model's update callback.

     The server subprocess is killed on every exit path, including a failed
     connection check or an exception raised by ``self.model.update``.
     """
     with subprocess.Popen(['python', '-m', 'visdom.server', '-port', str(self.port)]) as proc:
         try:
             # wait for visdom server startup: up to 5 polls, 1 s apart
             viz = Visdom(server=self.host, port=self.port)
             for _ in range(5):
                 time.sleep(1.0)  # seconds
                 if viz.check_connection():
                     break
             assert viz.check_connection()
             viz.close()
             self.model.update(self.corpus)
         finally:
             # Popen.__exit__ waits for the child to terminate and the server
             # never exits by itself, so it must be killed before leaving the
             # ``with`` block -- otherwise any failure above hangs forever.
             proc.kill()
Esempio n. 3
0
def get_default_visdom_env():
    """
    Build a Visdom client from the ``-port`` / ``-server`` command-line flags.

    The flags are parsed from ``sys.argv`` with ``argparse``. When absent they
    default to port 8097 and server "http://localhost".

    Returns
    -------
    Visdom
           Visdom client object for the requested server

    Raises
    ------
    AssertionError
           if ``check_connection(timeout_seconds=3)`` fails
    """
    cli = argparse.ArgumentParser(description='Demo arguments')
    # (flag, type, default, help); the metavar is the flag without its dash.
    flag_table = (
        ('-port', int, 8097, 'port the visdom server is running on.'),
        ('-server', str, "http://localhost",
         'Server address of the target to run the demo on.'),
    )
    for flag, flag_type, fallback, help_text in flag_table:
        cli.add_argument(flag,
                         metavar=flag.lstrip('-'),
                         type=flag_type,
                         default=fallback,
                         help=help_text)
    options = cli.parse_args()

    client = Visdom(port=options.port, server=options.server)
    assert client.check_connection(timeout_seconds=3), \
        'No connection could be formed quickly'
    return client
Esempio n. 4
0
def visdom_server():
    """Start a shared Visdom server once and yield its ``(hostname, port)``.

    The subprocess handle is kept in the module-level ``vd_server_process``,
    so later uses reuse the same server. Shutting it down is left to
    ``visdom_server_stop`` (defined elsewhere).

    Yields:
        tuple: ``(vd_hostname, vd_port)`` of the running server.

    Raises:
        AssertionError: if the freshly started server does not accept a
            connection. Before raising, the half-started process is killed
            and the global handle is reset, so a later attempt starts cleanly.
    """
    # Start Visdom server once and stop it with visdom_server_stop
    global vd_hostname, vd_port, vd_server_process

    if vd_server_process is None:

        import subprocess
        import time

        from visdom import Visdom
        from visdom.server import download_scripts

        download_scripts()

        vd_hostname = "localhost"
        # NOTE(review): nothing checks that this random port is actually free.
        vd_port = random.randint(8089, 8887)

        vd_server_process = subprocess.Popen(
            ["python", "-m", "visdom.server", "--hostname", vd_hostname, "-port", str(vd_port)]
        )
        time.sleep(5)

        vis = Visdom(server=vd_hostname, port=vd_port)
        if not vis.check_connection(timeout_seconds=10):
            # Don't leave an orphaned server behind. Don't keep a stale handle
            # either, or every later call would skip startup.
            vd_server_process.kill()
            vd_server_process.wait()
            vd_server_process = None
            raise AssertionError(
                "Visdom server did not accept connections on {}:{}".format(vd_hostname, vd_port)
            )
        vis.close()

    yield (vd_hostname, vd_port)
Esempio n. 5
0
def visualize(epoch, input):  # C,target,fin1,fin2,fin3):
    """Show ``input[0, 0]`` as an 'Electric' heatmap on the default Visdom server.

    Args:
        epoch: epoch number, used only in the window title.
        input: anything indexable as ``input[0, 0]``. Presumably a
            (batch, channel, H, W) tensor/array -- TODO confirm with callers.

    Raises:
        AssertionError: if ``Visdom()`` reports no connection.
    """
    from visdom import Visdom

    board = Visdom()
    assert board.check_connection()

    # Earlier revisions also took C/target/fin1/fin2/fin3 and showed each of
    # them per channel as images and heatmaps; that code was disabled.
    heat_opts = dict(colormap='Electric',
                     title='Epoch-{} input'.format(epoch))
    board.heatmap(input[0, 0], opts=heat_opts)
Esempio n. 6
0
class VisdomWriter(object):
    """One Visdom line chart that grows by one point per ``update`` call."""

    def __init__(self, title, xlabel='Epoch', ylabel='Loss'):
        """Extended Visdom Writer.

        Args:
            title: chart title, applied when the window is first created.
            xlabel: x-axis label, applied when the window is first created.
            ylabel: y-axis label, applied when the window is first created.

        Raises:
            AssertionError: if ``Visdom()`` reports no connection.
        """
        self.vis = Visdom()
        assert self.vis.check_connection()
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.x = 0  # x of the most recently plotted point (1 after the first update)
        self.win = None  # window handle, created lazily by the first update

    def update_text(self, text):
        """Text Memo (usually used to note hyperparameter-configurations)"""
        self.vis.text(text)

    def update(self, y):
        """Append ``y`` at the next step (X: Step (Epoch) / Y: loss)."""
        self.x += 1
        xs = np.array([self.x])
        ys = np.array([y])
        if self.win is None:
            self.win = self.vis.line(X=xs,
                                     Y=ys,
                                     opts=dict(
                                         title=self.title,
                                         xlabel=self.xlabel,
                                         ylabel=self.ylabel,
                                     ))
        else:
            # ``Visdom.updateTrace`` was deprecated and later removed.
            # Appending through ``line(update='append')`` is its replacement.
            self.vis.line(X=xs, Y=ys, win=self.win, update='append')
Esempio n. 7
0
class ZcsVisdom:
    """Small Visdom wrapper that tracks an append counter per window name."""

    def __init__(self, server='10.160.82.54', port=8899):
        self.vis = Visdom(server=server, port=port)
        assert self.vis.check_connection()

        # window name -> next x position used by ``append``
        self.wins = {}

    def plot(self, data, win):
        '''
        Redraw window ``win`` with ``data`` against x = 0 .. len(data) - 1.

        data: np.array
        win: str
        '''
        steps = np.arange(data.shape[0])
        self.vis.line(data, steps, win=win)

    def append(self, data, win, opts):
        '''
        Append one point per line to window ``win``.

        data is a list whose length equals the number of lines being drawn.
        The first time a window name is seen it is closed on the server and
        its x counter starts at 0. Each call then advances that counter by 1.
        '''
        assert isinstance(data, list)
        if win not in self.wins:
            self.vis.close(win)
            self.wins[win] = 0
        step = self.wins[win]
        self.vis.line(Y=[data], X=[step], win=win, update='append', opts=opts)
        self.wins[win] = step + 1
Esempio n. 8
0
    def visualize(
        self,
        viz: Visdom,
        visdom_env_imgs: str,
        preds: Dict[str, Any],
        prefix: str,
    ) -> None:
        """
        Send a visualization of the forward-pass predictions to Visdom.

        If ``viz`` reports no connection, only an info message is logged.

        Args:
            viz: Visdom connection object
            visdom_env_imgs: name of visdom environment for the images.
            preds: predictions dict like returned by forward()
            prefix: prepended to the names of images
        """
        if viz.check_connection():
            # Only the first image is labelled; its index goes into the title.
            first_image = 0
            vis_utils.visualize_basics(
                viz, preds, visdom_env_imgs, title=f"{prefix}_im{first_image}"
            )
        else:
            logger.info("no visdom server! -> skipping batch vis")
Esempio n. 9
0
def server_updateTrace_plot(x: np.array, y: np.array, win=None, xlabel='iters', ylabel='loss', vis=None):
    '''
    Plot, or extend, a line chart of ``y`` against ``x`` on a Visdom server.

    On the first call pass win=None. On later calls pass the win returned by
    the previous call, so the new points are appended to the same window.
    e.g.
    win=server_updateTrace_plot(x=np.array([1,2,3]),y=np.array([321,453,542]))
    server_updateTrace_plot(x=np.array([4,5,6]),y=np.array([312,345,453]),win=win)
    :param x: np.array of x values (reshaped to a single column)
    :param y: np.array of y values (reshaped to a single column)
    :param win: window id from a previous call, or None to create a window
    :param xlabel: x-axis label, used only when a new window is created
    :param ylabel: y-axis label, used only when a new window is created
    :param vis: existing Visdom client; a default ``Visdom()`` is created if None
    :return: the window id; ``win`` is returned unchanged if the server is unreachable
    '''
    from visdom import Visdom
    if vis is None:
        vis = Visdom()
        warnings.warn('may cause error when dealing with large dataset!')
    if not vis.check_connection():
        # Best-effort helper: warn and return ``win`` unchanged instead of raising.
        warnings.warn('execute "python -m visdom.server"')
        return win

    xs = x.reshape(-1, 1)
    ys = y.reshape(-1, 1)
    if win is None:
        win = vis.line(X=xs, Y=ys, opts=dict(
            xlabel=xlabel,
            ylabel=ylabel
        ))
    else:
        vis.line(X=xs, Y=ys, win=win, update='append')

    return win
Esempio n. 10
0
def main():
    """Train an AlexNet variant on CIFAR-100 and plot its progress to Visdom.

    Per-epoch loss/accuracy come from ``train`` and ``test`` (defined
    elsewhere), which also receive the open CSV files. Every
    ``VISUAL_EVERY_EPOCH`` epochs, the full loss and accuracy histories are
    plotted. No ``win`` is passed, so each plot opens a new window.
    """
    viz = Visdom(server='http://192.168.1.108', port=8097, env='cifar100')
    assert viz.check_connection()

    train_set = ds.CIFAR100('./data', train=True, transform=transform, download=True)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_OF_WORKERS)

    test_set = ds.CIFAR100('./data', train=False, transform=transform, download=True)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE * 2, shuffle=True, num_workers=NUM_OF_WORKERS)

    net = getDefaultAlexNet(INPUT_CHANNEL, NUM_CLASS, "dcm").cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=0.1)

    # Per-epoch histories used for plotting.
    train_loss_plot, test_loss_plot = np.array([]), np.array([])
    train_acc_plot, test_acc_plot = np.array([]), np.array([])

    # ------- Save training data -------------------------------------------------------------
    # ``with`` guarantees both CSV files are flushed and closed, even if training fails.
    with open(os.path.join(SAVE_DIR, 'train.csv'), 'w') as train_file, \
            open(os.path.join(SAVE_DIR, 'test.csv'), 'w') as test_file:
        train_file.write('Epoch,Loss,Training Accuracy\n')
        # No trailing space after the newline: it would prefix the first data row.
        test_file.write('Epoch,Loss,Test Accuracy\n')

        # Starting training process:
        for i in range(1, NUM_OF_EPOCHS + 1):
            print("---------------------------------------------------------------------------------------")
            train_loss, train_acc = train(i, net, optimizer, criterion, train_loader, train_file)
            train_loss_plot = np.append(train_loss_plot, train_loss)
            train_acc_plot = np.append(train_acc_plot, train_acc)

            test_loss, test_acc = test(i, net, criterion, test_loader, test_file)
            test_acc_plot = np.append(test_acc_plot, test_acc)
            test_loss_plot = np.append(test_loss_plot, test_loss)

            # One step per epoch. After the i-th call, last_epoch == i, matching
            # the deprecated ``step(epoch=i)`` form for StepLR.
            scheduler.step()
            print("---------------------------------------------------------------------------------------")

            if i % VISUAL_EVERY_EPOCH == 0:
                x_axis = range(1, i + 1)
                viz.line(
                    Y=np.column_stack((train_loss_plot, test_loss_plot)),
                    X=np.column_stack((x_axis, x_axis)),
                    opts=dict(
                        title="loss plot at epoch (%d)" % i,
                        linecolor=np.vstack((np.array(RGB_CYAN), np.array(RGB_MAGENTA))),
                        legend=["train loss", "test loss"]
                    )
                )

                viz.line(
                    Y=np.column_stack((train_acc_plot, test_acc_plot)),
                    X=np.column_stack((x_axis, x_axis)),
                    opts=dict(
                        title="accuracy plot at epoch (%d)" % i,
                        linecolor=np.vstack((np.array(RGB_DODGERBLUE), np.array(RGB_FIREBRICK))),
                        legend=["train accuracy", "test accuracy"]
                    )
                )
Esempio n. 11
0
class Logger():
    """Logger for training: prints key/value pairs and optionally plots them on Visdom."""

    def __init__(self, enable_visdom=False, curve_names=None):
        """
        Args:
            enable_visdom: connect to a default ``Visdom()`` server and plot
                the values named in ``curve_names``.
            curve_names: keyword names passed to ``log`` that should be
                plotted. Reset to None when visdom is disabled.
        """
        self.curve_names = curve_names
        if enable_visdom:
            self.vis = Visdom()
            assert self.vis.check_connection()
            # Automatic x coordinate, used when ``log`` is called without xval.
            self.curve_x = np.array([0])
        else:
            self.curve_names = None

    def log(self, xval=None, win_name='loss', **kwargs):
        """Print every keyword value and, if enabled, append it to its Visdom curve.

        Args:
            xval: x coordinate for this step (scalar or array). Defaults to an
                internal counter that starts at 0.
            win_name: Visdom window that receives the curves.
            **kwargs: name -> value pairs. All are printed; those whose name
                is in ``curve_names`` are also plotted, all at the same x.

        The internal counter advances by one per call that plots anything.
        """
        print("##############################################################")
        for key, value in kwargs.items():
            print(key, value, sep='\t')

        if not self.curve_names:
            return

        # Take a copy: ``self.curve_x += 1`` is in-place, and must not move the
        # x of points already handed to visdom in this call.
        x = self.curve_x.copy() if xval is None else np.atleast_1d(xval)
        plotted = False
        for name in self.curve_names:
            if name not in kwargs:
                continue
            self.vis.line(Y=np.array([kwargs[name]]),
                          X=x,
                          win=win_name,
                          update='append',
                          name=name,
                          opts=dict(showlegend=True))
            plotted = True
        if plotted:
            self.curve_x += 1

    def plot_curve(self, yvals, xvals, win_name='pr_curves'):
        """Plot curve."""
        self.vis.line(Y=np.array(yvals), X=np.array(xvals), win=win_name)

    def plot_marking_points(self, image, marking_points, win_name='mk_points'):
        """Draw marking points on ``image`` in red and show it on Visdom.

        ``image`` is drawn on in place with PIL's ImageDraw. Each point needs
        ``x``, ``y``, ``direction`` and ``shape`` attributes.
        ``x``/``y`` are scaled by the image size, so they are presumably
        normalized to [0, 1] -- TODO confirm. ``direction`` is passed to
        math.cos/sin, i.e. radians.
        """
        width, height = image.size
        draw = ImageDraw.Draw(image)
        for point in marking_points:
            p0_x = width * point.x
            p0_y = height * point.y
            # 50-pixel segment along the point's direction
            p1_x = p0_x + 50 * math.cos(point.direction)
            p1_y = p0_y + 50 * math.sin(point.direction)
            draw.line((p0_x, p0_y, p1_x, p1_y), fill=(255, 0, 0))
            # Perpendicular segment. Which form is drawn depends on ``shape``
            # (meaning of the 0.5 threshold not visible here).
            p2_x = p0_x - 50 * math.sin(point.direction)
            p2_y = p0_y + 50 * math.cos(point.direction)
            if point.shape > 0.5:
                draw.line((p2_x, p2_y, p0_x, p0_y), fill=(255, 0, 0))
            else:
                p3_x = p0_x + 50 * math.sin(point.direction)
                p3_y = p0_y - 50 * math.cos(point.direction)
                draw.line((p2_x, p2_y, p3_x, p3_y), fill=(255, 0, 0))
        image = np.asarray(image, dtype="uint8")
        # HWC -> CHW before sending to visdom
        image = np.transpose(image, (2, 0, 1))
        self.vis.image(image, win=win_name)
Esempio n. 12
0
def draw(items, labels):
    """Scatter-plot columns 4 and 5 of ``items`` on Visdom, grouped by ``labels``.

    Args:
        items: 2-D array. Column 4 is used as ``freq`` and column 5 as
            ``norm_F`` (meanings inferred from variable names only).
        labels: per-row labels passed to ``viz.scatter`` as ``Y``. The
            legend assumes two classes, 'NEG' and 'POS' -- presumably encoded
            as 1 and 2; TODO confirm.

    Raises:
        AssertionError: if ``Visdom()`` reports no connection.
    """
    length = len(items)

    viz = Visdom()
    assert viz.check_connection()

    Y = labels
    freq = items[:, 4].reshape((length, 1))
    norm_F = items[:, 5].reshape((length, 1))

    print('[INFO] draw : freq.shape', freq.shape)
    print('[INFO] draw : norm_F.shape', norm_F.shape)
    X = np.concatenate([freq, norm_F], 1)
    print('[INFO] draw : X.shape', X.shape)

    # 2-D scatter plot: X has exactly two columns (freq, norm_F).
    viz.scatter(
        X=X,
        Y=Y,
        opts=dict(
            legend=['NEG', 'POS'],
            markersize=5,
        )
    )
Esempio n. 13
0
def visdom_plot(total_num_steps, cum_reward):
    """Append one point to two live Visdom line charts.

    Module-level state used here:
    - ``vis``: client, created lazily (which also closes every window);
    - ``win1`` / ``win2``: window handles;
    - ``avg_reward``: last computed average;
    - ``X``, ``Y1``, ``Y2``: history lists.

    Args:
        total_num_steps: x value of the new point. It is also the divisor
            for the average, so it must be non-zero.
        cum_reward: cumulative reward, plotted as-is on ``win1`` and as
            ``cum_reward / total_num_steps`` on ``win2``.
    """
    from visdom import Visdom

    global vis, win1, win2, avg_reward

    if vis is None:
        vis = Visdom()
        assert vis.check_connection()
        # Close all existing plots
        vis.close()

    # Reward per step. An exponential running average was used here before:
    # avg_reward = avg_reward * 0.9 + 0.1 * cum_reward
    avg_reward = cum_reward / total_num_steps

    X.append(total_num_steps)
    Y1.append(cum_reward)
    Y2.append(avg_reward)

    def chart_opts(xlabel, ylabel):
        # Both charts share this layout; only the axis labels differ.
        return dict(xlabel=xlabel, ylabel=ylabel, ytickmin=0, width=900, height=500)

    # Passing the previous handle back as ``win`` redraws that same window.
    win1 = vis.line(X=np.array(X),
                    Y=np.array(Y1),
                    opts=chart_opts('Total number of steps', 'Cumulative reward'),
                    win=win1)
    # NOTE(review): X holds total_num_steps, yet this axis says episodes -- confirm.
    win2 = vis.line(X=np.array(X),
                    Y=np.array(Y2),
                    opts=chart_opts('Total number of episodes', 'Average Reward'),
                    win=win2)
Esempio n. 14
0
def main(config):
    """Train EDANet with Visdom monitoring, then run the test phase.

    Args:
        config: run configuration. ``nb_classes`` and ``batch_size`` are read
            here, and the object is forwarded to EDANet, MyDataset,
            for_train and for_test.

    Raises:
        AssertionError: if the Visdom server cannot be reached within 3 s.
    """
    # Loss weights: classes 0 and 1 are re-weighted, any other class keeps 1.
    loss_weight = torch.ones(config.nb_classes)
    loss_weight[0] = 1.53297775619
    loss_weight[1] = 7.63194124408

    model = EDANet(config=config)
    print(model)

    # create visdom; server and port come from the module-level ``args``
    viz = Visdom(server=args.server, port=args.port, env=model.name)
    assert viz.check_connection(timeout_seconds=3), \
        'No connection could be formed quickly'

    datasets = {subset: MyDataset(config=config, subset=subset)
                for subset in ('train', 'val', 'test')}

    def make_loader(subset, shuffle):
        # NOTE(review): drop_last=True also drops the final partial batch of
        # the val/test splits, so those metrics skip up to batch_size - 1
        # samples -- confirm this is intended.
        return DataLoader(dataset=datasets[subset],
                          batch_size=config.batch_size,
                          shuffle=shuffle,
                          num_workers=2,
                          drop_last=True)

    train_data_loader = make_loader('train', shuffle=True)
    valid_data_loader = make_loader('val', shuffle=False)
    test_data_loader = make_loader('test', shuffle=False)

    begin_time = datetime.datetime.now().strftime('%m%d_%H%M%S')

    for_train(model=model, config=config,
              train_data_loader=train_data_loader,
              valid_data_loader=valid_data_loader,
              begin_time=begin_time,
              resume_file=None,
              loss_weight=loss_weight,
              visdom=viz)

    # testing phase does not need visdom, just one scalar for loss, miou and accuracy
    for_test(model=model, config=config,
             test_data_loader=test_data_loader,
             begin_time=begin_time,
             loss_weight=loss_weight,
             do_predict=True,)
class VisdomWebServer(object):
    """Push matplotlib charts of training metrics to a local Visdom server."""

    def __init__(self):

        DEFAULT_PORT = 8097
        DEFAULT_HOSTNAME = "http://localhost"

        self.vis = Visdom(port=DEFAULT_PORT, server=DEFAULT_HOSTNAME)

    def _push_figure(self, win, series, legend=False):
        """Plot each ``(values, plot_kwargs)`` of ``series`` on a new figure and send it to window ``win``.

        The figure is closed in ``finally`` so repeated calls never pile up
        open matplotlib figures, even when plotting or sending fails.
        """
        fig, ax = plt.subplots()
        try:
            for values, plot_kwargs in series:
                ax.plot(values, **plot_kwargs)
            if legend:
                ax.legend()
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            ax.grid(zorder=0, color='lightgray', linestyle='--')
            # ``matplot`` receives the pyplot module; ``fig`` is the current figure here.
            self.vis.matplot(plt, win=win)
        finally:
            plt.close(fig)

    def update(self, metrics):
        """Send the learning-curve and learning-rate charts to Visdom.

        Args:
            metrics: mapping with 'train_loss', 'val_loss' and
                'learning_rate' sequences.

        If the server is unreachable, a notice is printed and nothing is sent.
        Plotting errors are printed and swallowed, so a visualization failure
        never interrupts training.
        """
        if not self.vis.check_connection():
            print('No connection could be formed quickly')
            return

        try:
            # Learning curve
            self._push_figure('lrcurve', [
                (metrics['train_loss'], dict(label='Training loss', color='#32526e')),
                (metrics['val_loss'], dict(label='Validation loss', color='#ff6b57')),
            ], legend=True)
            # Learning-rate schedule
            self._push_figure('lr_rate', [
                (metrics['learning_rate'], dict(color='#32526e')),
            ])
        except Exception as err:
            print('Skipped matplotlib example')
            print('Error message: ', err)
Esempio n. 16
0
class VisTorch:
    """Plot loss curves to a Visdom server, opening the connection lazily.

    The connection is created on the first ``plot_loss`` call and stored on
    the instance; the class attribute ``VIS_CON`` is only a ``None`` default.
    """

    DEFAULT_HOSTNAME = "127.0.0.1"
    DEFAULT_PORT = 8097
    VIS_CON = None

    def __init__(self, env_name=None):
        """
        Args:
            env_name: name stored on the instance. It defaults to the current
                local time formatted as ``%m%d%H%M%S``, and is not passed to
                Visdom anywhere in this class.
        """
        if env_name is None:
            env_name = str(datetime.now().strftime("%m%d%H%M%S"))
        self.env_name = env_name
        self.loss_window = None

    def close(self):
        """Call ``close()`` on the Visdom connection, if one has been opened.

        Before the first ``plot_loss`` there is no connection; this is then a
        no-op instead of an AttributeError on ``None``.
        """
        if self.VIS_CON is not None:
            self.VIS_CON.close()

    def __vis_initializer(self):
        """Create the Visdom connection if needed and assert that it is reachable."""
        if self.VIS_CON is None:
            self.VIS_CON = Visdom(server=self.DEFAULT_HOSTNAME,
                                  port=self.DEFAULT_PORT)

        assert self.VIS_CON.check_connection(
            timeout_seconds=3), 'No connection could be formed quickly'

    def plot_loss(self,
                  epoch,
                  *losses,
                  loss_type='Loss',
                  ytickmin=None,
                  ytickmax=None):
        """Append one point per loss series at x = ``epoch`` to the learning-curve window.

        Args:
            epoch: x value shared by every series.
            *losses: one scalar per series. The legend and colour tables
                below only cover three series.
            loss_type: y-axis label.
            ytickmin: forwarded as the visdom ``ytickmin`` option.
            ytickmax: forwarded as the visdom ``ytickmax`` option.
        """
        self.__vis_initializer()
        legend = ['Training', 'Evaluation', 'Training_1']
        linecolors = np.array([[0, 191, 255], [255, 10, 0], [255, 0, 255]])
        # First call creates the window; later calls append to it.
        self.loss_window = self.VIS_CON.line(
            Y=np.column_stack(losses),
            X=np.column_stack([epoch] * len(losses)),
            win=self.loss_window,
            update='append' if self.loss_window else None,
            opts={
                'xlabel': 'Epoch',
                'ylabel': loss_type,
                'ytickmin': ytickmin,
                'ytickmax': ytickmax,
                'title': 'Learning curve',
                'showlegend': True,
                'linecolor': linecolors[:len(losses)],
                'legend': legend[:len(losses)]
            })
Esempio n. 17
0
class VVisual():
    """Visdom helper: one line chart per logged key, plus a server smoke test."""

    def __init__(self, env):
        self.viz = Visdom(env=env)
        assert self.viz.check_connection(), 'No Visdom Connection'

    def line(self, log, x):
        """Plot ``log[k]`` at x-position ``x`` in the window named ``k``, for every key.

        A falsy ``x`` (e.g. 0) passes ``update=None``, so the window is drawn
        afresh. Any other ``x`` appends. Errors are printed, not raised, so
        logging never interrupts the caller (KeyboardInterrupt still propagates).
        """
        try:
            for k, v in log.items():
                self.viz.line(Y=np.array([v]),
                              X=np.array([x]),
                              win=k,
                              update='append' if x else None,
                              opts={'title': k})
        except Exception as e:
            print('Visdom exception: {}'.format(repr(e)))

    def testCallbackUpdateGraph(self):
        """Spawn visdom.server on ``self.port``, check it, then run ``self.model.update``.

        Expects ``port``, ``host``, ``model`` and ``corpus`` to be set on the
        instance; none of them are set in this class.
        """
        # Popen has no context manager in Python 2.7, hence try/finally.
        # ``proc`` is created before the ``try`` so that a failing Popen cannot
        # turn into a NameError on ``proc`` inside ``finally``.
        proc = subprocess.Popen(['python', '-m', 'visdom.server', '-port', str(self.port)])
        try:
            # wait for visdom server startup (any better way?)
            time.sleep(3)

            viz = Visdom(server=self.host, port=self.port)
            assert viz.check_connection()

            # clear screen
            viz.close()

            self.model.update(self.corpus)
        finally:
            proc.kill()
Esempio n. 20
0
def visdom_plot(total_num_steps, mean_reward):
    """Smooth ``mean_reward`` and redraw a live Visdom reward chart.

    Module-level state used here:
    - ``vis``: client, created lazily (which also closes every window);
    - ``win``: window handle;
    - ``avg_reward``: running average;
    - ``X`` / ``Y``: history lists.

    Args:
        total_num_steps: x value of the new point.
        mean_reward: new reward sample, blended into ``avg_reward``.
    """
    # Lazily import visdom so that people don't need to install visdom
    # if they're not actually using it
    from visdom import Visdom

    global vis, win, avg_reward

    if vis is None:
        vis = Visdom()
        assert vis.check_connection()
        # wipe any windows left over from earlier runs
        vis.close()

    # Exponential moving average (decay 0.9) for curve smoothing.
    avg_reward = avg_reward * 0.9 + 0.1 * mean_reward

    X.append(total_num_steps)
    Y.append(avg_reward)

    chart_opts = dict(xlabel='Total time steps',
                      ylabel='Reward per episode',
                      ytickmin=0,
                      width=900,
                      height=500)
    # Passing the previous handle back as ``win`` redraws the same window.
    win = vis.line(X=np.array(X), Y=np.array(Y), opts=chart_opts, win=win)
class visdom_recorder:
    """Record scalars and image grids into one Visdom window per tag."""

    def __init__(self):
        self.viz = Visdom()
        # this is only necessary if your network is somehow slow like mine
        time.sleep(1)
        assert self.viz.check_connection(), 'Visdom server connection failed, start server with python -m visdom.server and default port = 8097'
        # tag -> Visdom window handle
        self.windows = {}

    def add_scalar(self, tag, val, niter):
        """Append ``val`` at step ``niter`` to the line chart for ``tag``, creating it on first use."""
        # visdom wants arrays: wrap each scalar in a length-1 float array
        y_arr = np.ones([1]) * val
        x_arr = np.ones([1]) * niter
        if tag in self.windows:
            self.viz.line(Y=y_arr, X=x_arr, win=self.windows[tag], update='append')
        else:
            self.windows[tag] = self.viz.line(Y=y_arr, X=x_arr, opts=dict(title=tag))

    def add_image_grid(self, tag, x, niter):
        """Show image batch ``x`` in the window for ``tag``; ``niter`` only labels the first caption."""
        if tag in self.windows:
            self.viz.images(x, win=self.windows[tag])
        else:
            caption = 'At Iter {}'.format(niter)
            self.windows[tag] = self.viz.images(x, opts=dict(title=tag, caption=caption))
Esempio n. 22
0
class VisdomLinePlotter(object):
    """Plots to Visdom: one window per variable, one named trace per split."""

    def __init__(self, env_name='main'):
        self.viz = Visdom()
        assert (self.viz.check_connection())
        self.env = env_name
        # var_name -> Visdom window handle
        self.plots = {}

    def plot(self, var_name, split_name, title_name, x, y, update='append'):
        """Add points ``(x, y)`` for ``split_name`` to the chart of ``var_name``.

        The first call for a variable creates its window, and ``title_name``
        and ``split_name`` become its title and legend. Later calls add to
        the trace named ``split_name`` using the given ``update`` mode.
        """
        xs = np.array(x)
        ys = np.array(y)
        if var_name in self.plots:
            self.viz.line(X=xs,
                          Y=ys,
                          env=self.env,
                          win=self.plots[var_name],
                          name=split_name,
                          update=update)
            return
        chart_opts = dict(legend=[split_name],
                          title=title_name,
                          xlabel='X-axis',
                          ylabel=var_name)
        self.plots[var_name] = self.viz.line(X=xs, Y=ys, env=self.env, opts=chart_opts)
Esempio n. 23
0
    def testCallbackUpdateGraph(self):
        """Spawn visdom.server, check the connection, then run the update callback.

        The server subprocess is killed and reaped on every exit path. If
        ``self.model.update`` raises, the test fails with that error in the
        failure message.
        """
        # create a new process for visdom.server
        # NOTE(review): the server starts on its default port while the client
        # below uses FLAGS.port / FLAGS.server -- confirm that they match.
        proc = subprocess.Popen(['python', '-m', 'visdom.server'])
        try:
            # wait for visdom server startup
            time.sleep(3)

            # create visdom object
            print(FLAGS.port)
            print(FLAGS.server)
            viz = Visdom(port=FLAGS.port, server=FLAGS.server)

            # check connection
            assert viz.check_connection()

            # clear screen
            viz.close()

            # test callback's update graph function
            try:
                self.model.update(self.corpus)
            except Exception as e:
                self.fail('model.update raised {!r}'.format(e))
        finally:
            # visdom.server never exits on its own: kill it, then reap it.
            proc.kill()
            proc.wait()
            print('server killed')
Esempio n. 24
0
from __future__ import unicode_literals

from visdom import Visdom
import numpy as np
import math
import os.path
import getpass
import time
from sys import platform as _platform
from six.moves import urllib

try:
    viz = Visdom()

    startup_sec = 1
    while not viz.check_connection() and startup_sec > 0:
        time.sleep(0.1)
        startup_sec -= 0.1
    assert viz.check_connection(), 'No connection could be formed quickly'

    textwindow = viz.text('Hello World!')

    updatetextwindow = viz.text('Hello World! More text should be here')
    assert updatetextwindow is not None, 'Window was none'
    viz.text('And here it is', win=updatetextwindow, append=True)

    # text window with Callbacks
    txt = 'This is a write demo notepad. Type below. Delete clears text:<br>'
    callback_text_window = viz.text(txt)

    def type_callback(event):
Esempio n. 25
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from visdom import Visdom
import numpy as np
import math
import os.path
import getpass
from sys import platform as _platform
from six.moves import urllib

# Module-level demo client built with Visdom()'s default constructor
# arguments; the rest of the script reuses ``viz``.
viz = Visdom()

# Fail fast (AssertionError) if no server is reachable.
assert viz.check_connection()

# Open a text window.
textwindow = viz.text('Hello World!')

# Open a second text window, then add to it by passing its handle as ``win``
# together with ``append=True``.
updatetextwindow = viz.text('Hello World! More text should be here')
viz.text('And here it is', win=updatetextwindow, append=True)

# video demo:
try:
    video = np.empty([256, 250, 250, 3], dtype=np.uint8)
    for n in range(256):
        video[n, :, :, :].fill(n)
    viz.video(tensor=video)

    # video demo: download video from http://media.w3.org/2010/05/sintel/trailer.ogv
    video_url = 'http://media.w3.org/2010/05/sintel/trailer.ogv'
Esempio n. 26
0
def pretrain(config):
    """Train ``PretrainedModel`` with cross-entropy loss, optionally resuming.

    Every epoch in ``range(config.model_prefix, config.epochs)`` runs a
    'train' phase and an 'eval' phase over ``ClsDataLoader`` datasets.
    Per-step losses are printed and logged. When ``config.use_visdom`` is
    set, they are also plotted to three visdom line windows: training loss,
    evaluation loss and learning rate.

    After the train phase, a checkpoint ``<out_dir>/models/<epoch>.pt`` is
    written when ``epoch % config.model_intervals == 0`` and the mean training
    loss is lower than the best seen so far. The best loss is kept in the
    module-level global ``min_loss``.

    Args:
        config: object exposing data_dir, data_size, batch_size, out_channels,
            model_prefix, out_dir, lr, device, use_visdom, epochs and
            model_intervals.
    """
    image_datasets = {phase: ClsDataLoader(data_dir=config.data_dir,
                                           phase=phase,
                                           data_size=config.data_size)
                      for phase in ['train', 'eval']}

    print('loading dataset: train: {}, eval: {}'
          .format(len(image_datasets['train']),
                  len(image_datasets['eval'])))

    dataset_loaders = {'train': data.DataLoader(image_datasets['train'],
                                                batch_size=config.batch_size,
                                                shuffle=True,
                                                num_workers=4),

                       'eval': data.DataLoader(image_datasets['eval'],
                                               batch_size=config.batch_size,
                                               shuffle=False,
                                               num_workers=0)
                       }

    model = PretrainedModel(config.out_channels)

    if config.model_prefix > 0:
        model_file = os.path.join(config.out_dir, 'models', str(config.model_prefix) + '.pt')

        assert os.path.exists(model_file), \
            'pretrained model file ({}) does not exist, please check'.format(model_file)

        checkpoint = torch.load(model_file, map_location='cpu')
        model.load_state_dict(checkpoint['seg'], strict=False)
        # Checkpoints written below store the learning rate under 'lr1'.
        # Reading only 'lr' raised KeyError on them, so fall back through a
        # legacy 'lr' key and finally config.lr.
        resume_lr = checkpoint.get('lr1', checkpoint.get('lr', config.lr))
        opt1 = torch.optim.Adam(model.parameters(), lr=resume_lr)

        print('loading checkpoint from {}'.format(str(config.model_prefix) + '.pt'))
        # Checkpoints from earlier runs were saved without a 'loss' entry.
        print('loss: {}'.format(checkpoint.get('loss')))
    else:
        opt1 = torch.optim.Adam(model.parameters(), lr=config.lr)

    lr_scheduler_1 = torch.optim.lr_scheduler.ExponentialLR(opt1, gamma=0.99)

    CELoss = nn.CrossEntropyLoss()

    print('running on {}'.format(config.device))

    # set visdom
    if config.use_visdom:
        viz = Visdom()
        assert viz.check_connection()
        visline1 = viz.line(
            X=torch.Tensor([1]).cpu() * config.model_prefix,
            Y=torch.Tensor([0]).cpu(),
            win=1,
            opts=dict(xlabel='epochs',
                      ylabel='loss',
                      title='training loss',
                      )
        )
        visline2 = viz.line(
            X=torch.Tensor([1]).cpu() * config.model_prefix,
            Y=torch.Tensor([0]).cpu(),
            win=2,
            opts=dict(xlabel='epochs',
                      ylabel='loss',
                      title='evaluation loss',
                      )
        )
        visline3 = viz.line(
            X=torch.Tensor([1]).cpu() * config.model_prefix,
            Y=torch.Tensor([0]).cpu(),
            win=3,
            opts=dict(xlabel='epochs',
                      ylabel='LR',
                      title='Learning rate')
        )

    model = model.to(device=config.device)

    global_steps = {'train': 0, 'eval': 0}

    global min_loss
    min_loss = 1e5

    for epoch in range(config.model_prefix, config.epochs):
        for phase in ['train', 'eval']:
            running_loss = []

            if phase == 'train':
                model.train()
            else:
                model.eval()

            num_samples = len(image_datasets[phase])

            for i, (images, labels) in enumerate(dataset_loaders[phase]):

                start = time.time()

                images = images.to(device=config.device)
                labels = labels.to(device=config.device)

                opt1.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):

                    outputs = model(images)
                    loss = CELoss(outputs, labels)

                    running_loss.append(loss.item())

                    if phase == 'train':
                        loss.backward()
                        opt1.step()

                end = time.time()

                print('*' * 20)
                print('epoch: {}/{}  {}_global_steps: {}  processing_time: {:.4f} s  LR: {:.8f}'.
                      format(epoch, config.epochs, phase, global_steps[phase],
                             end - start, opt1.param_groups[0]['lr']))
                print('{} loss: {:.6}'.format(phase, loss.item()))

                if phase == 'train' and i % 10 == 0:
                    logging.info('epoch:{} steps:{} processing_time:{:.4f}s LR:{:.8f} loss:{:.6}'.
                                 format(epoch, global_steps[phase], end - start,
                                        opt1.param_groups[0]['lr'], loss.item()))
                if phase == 'eval':
                    logging.info('eval_epoch:{} steps:{} processing_time:{:.4f}s LR:{:.8f} loss:{:.6}'.
                                 format(epoch, global_steps[phase], end - start,
                                        opt1.param_groups[0]['lr'], loss.item()))

                # set visdom
                if config.use_visdom and i % 5 == 0:
                    # x coordinate: fractional epoch reached within this phase.
                    x_pos = torch.Tensor([1]).cpu() * (epoch + i * config.batch_size / num_samples)
                    if phase == 'train':
                        viz.line(
                            X=x_pos,
                            Y=torch.Tensor([loss.item()]).cpu(),
                            win=visline1,
                            update='append'
                        )
                        viz.line(
                            X=x_pos,
                            Y=torch.Tensor([opt1.param_groups[0]['lr']]),
                            win=visline3,
                            update='append'
                        )
                    else:
                        viz.line(
                            X=x_pos,
                            Y=torch.Tensor([loss.item()]),
                            win=visline2,
                            update='append'
                        )

                global_steps[phase] += 1

            if not running_loss:
                # Empty loader for this phase: nothing to average or checkpoint
                # (averaging would divide by zero).
                continue

            current_loss = sum(running_loss) / len(running_loss)

            if phase == 'train' and epoch % config.model_intervals == 0 and current_loss < min_loss:
                torch.save({
                            'epoch': epoch,
                            'seg': model.state_dict(),
                            'lr1': opt1.param_groups[0]['lr'],
                            # Read back by the resume branch above.
                            'loss': current_loss,
                            },
                           os.path.join(config.out_dir, 'models', str(epoch) + '.pt')
                           )
                min_loss = current_loss

                print('Saving model in {} for {} epoches'.format(config.out_dir + '/models', epoch))

        # Decay the learning rate once per epoch, after the optimizer has
        # stepped. Since PyTorch 1.1, stepping the scheduler first skips the
        # initial learning rate.
        lr_scheduler_1.step()
Esempio n. 27
0
class VisdomWriter:
    """TensorBoard-style writer that forwards summaries to a Visdom server.

    Scalars are also kept in memory (``self.scalar_dict``) so that they can be
    exported with ``export_scalars_to_json``.
    """

    def __init__(self):
        try:
            from visdom import Visdom
        except ImportError as err:
            raise ImportError("Visdom visualization requires installation of Visdom") from err

        # {main_tag: {tag: [value, ...]}}
        self.scalar_dict = {}
        self.server_connected = False
        self.vis = Visdom()
        # plot name -> visdom window id returned when the plot was created
        self.windows = {}

        self._try_connect()

    def _try_connect(self):
        """Poll the server for up to ~1 second; assert that it became reachable."""
        startup_sec = 1
        self.server_connected = self.vis.check_connection()
        while not self.server_connected and startup_sec > 0:
            time.sleep(0.1)
            startup_sec -= 0.1
            self.server_connected = self.vis.check_connection()
        assert self.server_connected, 'No connection could be formed quickly'

    @_check_connection
    def add_scalar(self, tag, scalar_value, global_step=None, main_tag='default'):
        """Add scalar data to Visdom. Plots the values in a plot titled
           {main_tag}-{tag}.

        Args:
            tag (string): Data identifier
            scalar_value (float or string/blobname): Value to save
            global_step (int): Global step value to record. When None, the
                number of values recorded so far for this tag is used.
            main_tag (string): Data group identifier
        """
        tag_values = self.scalar_dict.setdefault(main_tag, {})
        exists = tag in tag_values
        # Append in place: rebuilding the list on every call was O(n) per call.
        tag_values.setdefault(tag, []).append(scalar_value)
        plot_name = '{}-{}'.format(main_tag, tag)
        # If there is no global_step provided, follow sequential order.
        # Compare against None so that an explicit step of 0 is honoured.
        x_val = len(tag_values[tag]) if global_step is None else global_step
        if exists:
            # Update our existing Visdom window
            self.vis.line(
                X=make_np(x_val),
                Y=make_np(scalar_value),
                name=plot_name,
                update='append',
                win=self.windows[plot_name],
            )
        else:
            # Save the window if we are creating this graph for the first time
            self.windows[plot_name] = self.vis.line(
                X=make_np(x_val),
                Y=make_np(scalar_value),
                name=plot_name,
                opts={
                    'title': plot_name,
                    'xlabel': 'timestep',
                    'ylabel': tag,
                },
            )

    @_check_connection
    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None):
        """Adds many scalar data to summary.

        Note that this function also keeps logged scalars in memory. In extreme case it explodes your RAM.

        Args:
            main_tag (string): Data group identifier
            tag_scalar_dict (dict): Key-value pair storing the tag and corresponding values
            global_step (int): Global step value to record

        Examples::

            writer.add_scalars('run_14h',{'xsinx':i*np.sin(i/r),
                                          'xcosx':i*np.cos(i/r),
                                          'arctanx': numsteps*np.arctan(i/r)}, i)
            This function adds three plots:
                'run_14h-xsinx',
                'run_14h-xcosx',
                'run_14h-arctanx'
            with the corresponding values.
        """
        for key, value in tag_scalar_dict.items():
            self.add_scalar(key, value, global_step, main_tag)

    @_check_connection
    def export_scalars_to_json(self, path):
        """Write every scalar recorded so far to ``path`` as JSON, then forget them.

        The file has the shape ``{main_tag: {tag: [value, ...], ...}, ...}``.
        The values are stored exactly as passed to ``add_scalar``, so they
        must be JSON-serializable for ``json.dump`` to succeed.

        After the export, the in-memory scalars are cleared. The next
        ``add_scalar`` call for a tag then opens a new Visdom window.
        """
        with open(path, "w") as f:
            json.dump(self.scalar_dict, f)
        self.scalar_dict = {}

    @_check_connection
    def add_histogram(self, tag, values, global_step=None, bins='tensorflow'):
        """Add histogram to summary.

        Args:
            tag (string): Data identifier
            values (torch.Tensor, numpy.array, or string/blobname): Values to build histogram
            global_step (int): Global step value to record (not used by this backend)
            bins (string): one of {'tensorflow', 'auto', 'fd', ...}, this determines how the bins are made. You can find
              other options in: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
              (not used by this backend; Visdom bins the data itself)
        """
        values = make_np(values)
        self.vis.histogram(values, opts={'title': tag})

    @_check_connection
    def add_image(self, tag, img_tensor, global_step=None, caption=None):
        """Add image data to summary.

        Note that this requires the ``pillow`` package.

        Args:
            tag (string): Data identifier
            img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
            global_step (int): Global step value to record
        Shape:
            img_tensor: :math:`(C, H, W)`. Use ``torchvision.utils.make_grid()`` to prepare it is a good idea.
            C = colors (can be 1 - grayscale, 3 - RGB, 4 - RGBA)
        """
        img_tensor = make_np(img_tensor)
        self.vis.image(img_tensor, opts={'title': tag, 'caption': caption})

    @_check_connection
    def add_figure(self, tag, figure, global_step=None, close=True):
        """Render matplotlib figure into an image and add it to summary.

        Note that this requires the ``matplotlib`` package.

        Args:
            tag (string): Data identifier
            figure (matplotlib.pyplot.figure) or list of figures: figure or a list of figures
            global_step (int): Global step value to record
            close (bool): Flag to automatically close the figure
        """
        self.add_image(tag, figure_to_image(figure, close), global_step)

    @_check_connection
    def add_video(self, tag, vid_tensor, global_step=None, fps=4):
        """Add video data to summary.

        Note that this requires the ``moviepy`` package.

        Args:
            tag (string): Data identifier
            vid_tensor (torch.Tensor): Video data
            global_step (int): Global step value to record
            fps (float or int): Frames per second
        Shape:
            vid_tensor: :math:`(B, C, T, H, W)`. (if following tensorboard-pytorch format)
            vid_tensor: :math:`(T, H, W, C)`. (if following visdom format)
            B = batches, C = colors (1, 3, or 4), T = time frames, H = height, W = width
        """
        shape = vid_tensor.shape
        # A batch of videos (tensorboard-pytorch format) is a 5D tensor
        if len(shape) > 4:
            for i in range(shape[0]):
                # Reshape each video to Visdom's (T x H x W x C) and write each video
                if isinstance(vid_tensor, np.ndarray):
                    ind_vid = torch.from_numpy(vid_tensor[i, :, :, :, :]).permute(1, 2, 3, 0)
                else:
                    ind_vid = vid_tensor[i, :, :, :, :].permute(1, 2, 3, 0)
                # Values strictly inside (0, 1) are taken to mean a [0, 1] float
                # range, which is rescaled to [0, 255] before the uint8 cast.
                scale_factor = 255 if np.any((ind_vid > 0) & (ind_vid < 1)) else 1
                # Visdom looks for .ndim attr, this is something raw Tensors don't have
                # Cast to Numpy array to get .ndim attr
                ind_vid = ind_vid.numpy()
                ind_vid = (ind_vid * scale_factor).astype(np.uint8)
                assert ind_vid.shape[3] in [1, 3, 4], \
                    'Visdom requires the last dimension to be color, which can be 1 (grayscale), 3 (RGB) or 4 (RGBA)'
                self.vis.video(tensor=ind_vid, opts={'fps': fps})
        else:
            self.vis.video(tensor=vid_tensor, opts={'fps': fps})

    @_check_connection
    def add_audio(self, tag, snd_tensor, global_step=None, sample_rate=44100):
        """Add audio data to summary.

        Args:
            tag (string): Data identifier
            snd_tensor (torch.Tensor, numpy.array, or string/blobname): Sound data
            global_step (int): Global step value to record
            sample_rate (int): sample rate in Hz

        Shape:
            snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1].
        """
        snd_tensor = make_np(snd_tensor)
        self.vis.audio(tensor=snd_tensor, opts={'sample_frequency': sample_rate})

    @_check_connection
    def add_text(self, tag, text_string, global_step=None):
        """Add text data to summary.

        Args:
            tag (string): Data identifier
            text_string (string): String to save
            global_step (int): Global step value to record
        Examples::
            writer.add_text('lstm', 'This is an lstm', 0)
            writer.add_text('rnn', 'This is an rnn', 10)
        """
        if text_string is None:
            # Visdom doesn't support tags, write the tag as the text_string
            text_string = tag
        self.vis.text(text_string)

    @_check_connection
    def add_graph_onnx(self, prototxt):
        # TODO: Visdom doesn't support graph visualization yet, so this is a no-op
        return

    @_check_connection
    def add_graph(self, model, input_to_model=None, verbose=False, **kwargs):
        # TODO: Visdom doesn't support graph visualization yet, so this is a no-op
        return

    @_check_connection
    def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, tag='default', metadata_header=None):
        # TODO: Visdom doesn't support embeddings yet, so this is a no-op
        return

    @_check_connection
    def add_pr_curve(self, tag, labels, predictions, global_step=None, num_thresholds=127, weights=None):
        """Adds precision recall curve.

        Args:
            tag (string): Data identifier
            labels (torch.Tensor, numpy.array, or string/blobname): Ground truth data. Binary label for each element.
            predictions (torch.Tensor, numpy.array, or string/blobname):
            The probability that an element be classified as true. Value should in [0, 1]
            global_step (int): Global step value to record
            num_thresholds (int): Number of thresholds used to draw the curve.

        """
        labels, predictions = make_np(labels), make_np(predictions)
        raw_data = compute_curve(labels, predictions, num_thresholds, weights)

        # compute_curve returns np.stack((tp, fp, tn, fn, precision, recall))
        # We want to access 'precision' and 'recall'
        precision, recall = raw_data[4, :], raw_data[5, :]

        self.vis.line(
            X=recall,
            Y=precision,
            name=tag,
            opts={
                'title': 'PR Curve for {}'.format(tag),
                'xlabel': 'recall',
                'ylabel': 'precision',
            },
        )

    @_check_connection
    def add_pr_curve_raw(self, tag, true_positive_counts,
                         false_positive_counts,
                         true_negative_counts,
                         false_negative_counts,
                         precision,
                         recall, global_step=None, num_thresholds=127, weights=None):
        """Adds precision recall curve with raw data.

        Only ``precision`` and ``recall`` are plotted; the count arguments are
        accepted for API compatibility.

        Args:
            tag (string): Data identifier
            true_positive_counts (torch.Tensor, numpy.array, or string/blobname): true positive counts
            false_positive_counts (torch.Tensor, numpy.array, or string/blobname): false positive counts
            true_negative_counts (torch.Tensor, numpy.array, or string/blobname): true negative counts
            false_negative_counts (torch.Tensor, numpy.array, or string/blobname): false negative counts
            precision (torch.Tensor, numpy.array, or string/blobname): precision
            recall (torch.Tensor, numpy.array, or string/blobname): recall
            global_step (int): Global step value to record
            num_thresholds (int): Number of thresholds used to draw the curve.
            see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/README.md
        """
        precision, recall = make_np(precision), make_np(recall)
        self.vis.line(
            X=recall,
            Y=precision,
            name=tag,
            opts={
                'title': 'PR Curve for {}'.format(tag),
                'xlabel': 'recall',
                'ylabel': 'precision',
            },
        )

    def close(self):
        """Drop the Visdom client and the recorded scalars, then run the GC."""
        del self.vis
        del self.scalar_dict
        gc.collect()
Esempio n. 28
0
# Command-line options selecting which visdom server the demo talks to.
parser.add_argument(
    '-port', metavar='port', type=int, default=DEFAULT_PORT,
    help='port the visdom server is running on.',
)
parser.add_argument(
    '-server', metavar='server', type=str, default=DEFAULT_HOSTNAME,
    help='Server address of the target to run the demo on.',
)
FLAGS = parser.parse_args()

try:
    viz = Visdom(port=FLAGS.port, server=FLAGS.server)

    assert viz.check_connection(timeout_seconds=3), \
        'No connection could be formed quickly'

    textwindow = viz.text('Hello World!')

    updatetextwindow = viz.text('Hello World! More text should be here')
    assert updatetextwindow is not None, 'Window was none'
    viz.text('And here it is', win=updatetextwindow, append=True)

    # text window with Callbacks
    txt = 'This is a write demo notepad. Type below. Delete clears text:<br>'
    callback_text_window = viz.text(txt)

    def type_callback(event):
        if event['event_type'] == 'KeyPress':
            curr_txt = event['pane_data']['content']
Esempio n. 29
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from visdom import Visdom
import numpy as np
import math
import os.path
import getpass
from sys import platform as _platform
from six.moves import urllib

viz = Visdom()

# Give a just-started server a few seconds to answer, and fail with an
# explicit error: a bare `assert` would be stripped under `python -O`.
if not viz.check_connection(timeout_seconds=3):
    raise RuntimeError('No connection could be formed quickly')

textwindow = viz.text('Hello World!')

updatetextwindow = viz.text('Hello World! More text should be here')
if updatetextwindow is None:
    raise RuntimeError('Window was none')
viz.text('And here it is', win=updatetextwindow, append=True)

# video demo:
try:
    video = np.empty([256, 250, 250, 3], dtype=np.uint8)
    for n in range(256):
        video[n, :, :, :].fill(n)
    viz.video(tensor=video)

    # video demo: download video from http://media.w3.org/2010/05/sintel/trailer.ogv
Esempio n. 30
0
 def check_visdom_works(self):
     """Verify that a visdom server answers at the configured address.

     The client is built from ``self.defaults["server"]``, with ``http://``
     prepended, and ``self.defaults["port"]``.

     Raises:
         Exception: "Error: Check Visdom Server Setup" when the connection
             check reports failure or itself raises.
     """
     viz = Visdom(server='http://' + self.defaults["server"], port=self.defaults["port"])
     # Catch only ordinary errors: the previous bare `except:` also turned
     # KeyboardInterrupt/SystemExit into this setup error.
     try:
         connected = viz.check_connection()
     except Exception as err:
         raise Exception("Error: Check Visdom Server Setup") from err
     # Explicit test instead of `assert`, which `python -O` strips, and which
     # would silently turn this check into a no-op.
     if not connected:
         raise Exception("Error: Check Visdom Server Setup")
Esempio n. 31
0
def main(config, args):
    """Train ESFNet on the configured data, then evaluate it on the test split.

    Args:
        config: model/training configuration. ``nb_classes`` and
            ``batch_size`` are read here, and the object is forwarded to the
            model, the datasets and the train/test routines.
        args: parsed command-line options. ``server``, ``port``, ``threads``
            and ``weight`` are read here, and the object is forwarded as well.
    """
    # Loss weights: classes 0 and 1 get fixed weights, every other class
    # keeps the default of 1.
    loss_weight = torch.ones(config.nb_classes)
    loss_weight[0], loss_weight[1] = 1.53297775619, 7.63194124408

    # The model only reads nb_classes from config, so args is not passed in.
    model = ESFNet.ESFNet(config=config)
    print(model)

    # Visdom client, one environment per model name.
    viz = Visdom(server=args.server, port=args.port, env=model.name)
    assert viz.check_connection(timeout_seconds=3), \
        'No connection could be formed quickly'

    # TODO there are somewhat still need to change in ../configs/config.cfg
    datasets = {
        subset: MyDataset(config=config, args=args, subset=subset)
        for subset in ('train', 'val', 'test')
    }

    def build_loader(subset, shuffle):
        # NOTE(review): drop_last=True also applies to the val/test loaders, so
        # their final partial batch is never evaluated — confirm intended.
        return DataLoader(dataset=datasets[subset],
                          batch_size=config.batch_size,
                          shuffle=shuffle,
                          pin_memory=True,
                          num_workers=args.threads,
                          drop_last=True)

    loaders = {
        'train': build_loader('train', shuffle=True),
        'val': build_loader('val', shuffle=False),
        'test': build_loader('test', shuffle=False),
    }

    begin_time = datetime.datetime.now().strftime('%m%d_%H%M%S')

    for_train(model=model,
              config=config,
              args=args,
              train_data_loader=loaders['train'],
              valid_data_loader=loaders['val'],
              begin_time=begin_time,
              resume_file=args.weight,
              loss_weight=loss_weight,
              visdom=viz)

    # The testing phase needs no visdom: it reports one scalar each for
    # loss, mIoU and accuracy.
    for_test(
        model=model,
        config=config,
        args=args,
        test_data_loader=loaders['test'],
        begin_time=begin_time,
        resume_file=args.weight,
        loss_weight=loss_weight,
    )
                        help='learning rate (default: 0.001)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0)')
    parser.add_argument('--save-model', action='store_true', default=True,
                        help='save the current Model')
    parser.add_argument('--save-directory', type=str, default='save_model',
                        help='learnt models are saving here')
    parser.add_argument('--class-num', type=int, default=62,
                        help='class num')
    parser.add_argument('--work', type=str, default='finetune',  # train, eval, finetune, predict
                        help='training, eval, predicting or finetuning')
    args = parser.parse_args()

    # visdom可视化设置
    vis = Visdom(env="traffic-sign-class 20200902")
    assert vis.check_connection()
    opts1 = {
        "title": 'loss of mean/max/min in epoch',
        "xlabel": 'epoch',
        "ylabel": 'loss',
        "width": 1000,
        "height": 400,
        "legend": ['train_mean_loss', 'train_max_loss', 'train_min_loss', 'test_mean_loss', 'test_max_loss',
                   'test_min_loss']
    }
    opts2 = {
        "title": 'precision recall with epoch',
        "xlabel": 'epoch',
        "ylabel": 'precision/recall',
        "width": 1000,
        "height": 400,
Esempio n. 33
0
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from visdom import Visdom
import time
import numpy as np

try:
    viz = Visdom()

    startup_sec = 1
    while not viz.check_connection() and startup_sec > 0:
        time.sleep(0.1)
        startup_sec -= 0.1
    assert viz.check_connection(), 'No connection could be formed quickly'

    # image callback demo
    def show_color_image_window(color, win=None):
        image = np.full([3, 256, 256], color, dtype=float)
        return viz.image(
            image,
            opts=dict(title='Colors', caption='Press arrows to alter color.'),
            win=win
        )

    image_color = 0
    callback_image_window = show_color_image_window(image_color)