コード例 #1
0
class VisdomPlotter(object):
    """Plots line charts and image grids to a Visdom server.

    Window ids returned by Visdom are cached in ``self.plots`` under
    ``var_name`` so that later calls update the same window.
    """
    def __init__(self, env_name='gan', server="http://0.0.0.0", port=800):
        """Create the Visdom client.

        Args:
            env_name: Visdom environment every window is drawn into.
            server: Visdom server URL.
            port: Visdom server port. Defaults to 800, the value this class
                always used; note that Visdom's own default port is 8097.
        """
        self.viz = Visdom(server=server, port=port)
        self.env = env_name
        self.plots = {}

    def plot(self, var_name, split_name, x, y, xlabel='epoch'):
        """Add point (x, y) to series ``split_name`` of the ``var_name`` chart.

        The first call for ``var_name`` creates the chart (its title and
        y-axis label are ``var_name``); later calls append to it.
        """
        if var_name not in self.plots:
            # The first point is sent twice, giving Visdom a two-point series
            # to create the window with.
            self.plots[var_name] = self.viz.line(X=np.array([x, x]),
                                                 Y=np.array([y, y]),
                                                 env=self.env,
                                                 opts=dict(legend=[split_name],
                                                           title=var_name,
                                                           xlabel=xlabel,
                                                           ylabel=var_name))
        else:
            self.viz.line(X=np.array([x]),
                          Y=np.array([y]),
                          update="append",
                          env=self.env,
                          win=self.plots[var_name],
                          name=split_name)

    def draw(self, var_name, images):
        """Show ``images`` in the window for ``var_name`` (created on first use)."""
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.images(images, env=self.env)
        else:
            self.viz.images(images, env=self.env, win=self.plots[var_name])
コード例 #2
0
def main():
    """Load the trained captcha CNN and print/show its prediction per batch.

    The network output is split into 4 consecutive slices of
    ``captcha_setting.ALL_CHAR_SET_LEN`` scores; the arg-max of each slice
    picks one character of ``captcha_setting.ALL_CHAR_SET``. Only the first
    sample of each batch is decoded.
    """
    cnn = CNN()
    cnn.load_state_dict(torch.load('model.pkl'))
    # Inference mode: disables dropout / uses running batch-norm statistics
    # if the CNN has such layers (a no-op otherwise).
    cnn.eval()
    print("load cnn net.")

    predict_dataloader = my_dataset.get_predict_data_loader()
    char_set = captcha_setting.ALL_CHAR_SET
    set_len = captcha_setting.ALL_CHAR_SET_LEN
    num_chars = 4

    vis = Visdom()
    # No gradients are needed for prediction.
    with torch.no_grad():
        for images, _labels in predict_dataloader:
            predict_label = cnn(images)
            scores = predict_label[0].cpu().numpy()
            c = ''.join(
                str(char_set[np.argmax(scores[k * set_len:(k + 1) * set_len])])
                for k in range(num_chars))
            print(c)
            vis.images(images, opts=dict(caption=c))
コード例 #3
0
class SampleImage(object):
    """Render a GAN sample grid from fixed latent codes into one Visdom window.

    The latent batch is sampled once (on the first call) and reused, so
    successive grids are comparable. ``n_row`` is passed to Visdom as the
    grid's ``nrow`` argument.
    """
    def __init__(self, n_row, n_col):
        self.n_row, self.n_col = n_row, n_col
        self.viz = Visdom()
        self.fixed_z = None  # latent batch, sampled lazily
        self.win = None      # Visdom window id, set by the first draw

    def __call__(self, trainer, gan_model):
        step = trainer.global_step if trainer else 0
        if self.fixed_z is None:
            n_samples = self.n_row * self.n_col
            self.fixed_z = gan_model.sample_latent_code(n_samples)

        generated = gan_model.generate('fixed', self.fixed_z)
        # Shift/scale by (v + 1) / 2; presumably maps a [-1, 1] generator
        # output into [0, 1] for display.
        grid = (generated + 1.) / 2.

        caption = dict(caption='Step: {}'.format(step))
        if self.win is not None:
            self.viz.images(grid, nrow=self.n_row, win=self.win, opts=caption)
        else:
            self.win = self.viz.images(grid, nrow=self.n_row, opts=caption)
class VisdomPlotter(object):
    """Plots to a Visdom server: line charts and image grids."""
    def __init__(self, env_name='gan'):
        """Initialize the Visdom client, the target environment and the window cache."""
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}  # var_name -> Visdom window id

    def plot(self, var_name, split_name, x, y, xlabel='epoch'):
        """Plot one point in the chart for ``var_name``.

        Usually generator/discriminator loss values. The first call creates
        the chart; later calls append (x, y) to series ``split_name``.
        """
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(X=np.array([x, x]),
                                                 Y=np.array([y, y]),
                                                 env=self.env,
                                                 opts=dict(legend=[split_name],
                                                           title=var_name,
                                                           xlabel=xlabel,
                                                           ylabel=var_name))
        else:
            # Visdom removed ``updateTrace``; ``line(update='append')`` is
            # its supported replacement.
            self.viz.line(X=np.array([x]),
                          Y=np.array([y]),
                          env=self.env,
                          win=self.plots[var_name],
                          name=split_name,
                          update='append')

    def draw(self, var_name, images):
        """Draw ``images`` in the window for ``var_name`` (created on first use)."""
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.images(images, env=self.env)
        else:
            self.viz.images(images, env=self.env, win=self.plots[var_name])
コード例 #5
0
ファイル: data.py プロジェクト: kylehkhsu/neural-tangents
    def test_omniglot():
        """Visually inspect Omniglot few-shot tasks in Visdom.

        For every task in every batch that ``taskbatch`` draws from the
        validation split, shows the support (train) and query (test) images
        with their labels as text panes, then saves all Visdom environments.
        """
        viz = Visdom(port=8000, env='main')
        splits = load_omniglot()

        n_way, n_support, n_query = 3, 5, 7
        batch_size = 2
        batches = taskbatch(omniglot_task,
                            batch_size=batch_size,
                            n_task=batch_size,
                            split_dict=splits['val'],
                            n_way=n_way,
                            n_support=n_support,
                            n_query=n_query)

        for batch in batches:
            for i_task in range(batch_size):
                x_train = batch['x_train'][i_task]
                x_test = batch['x_test'][i_task]
                y_train = batch['y_train'][i_task]
                y_test = batch['y_test'][i_task]

                # The (0, 3, 1, 2) transpose moves the last axis to position
                # 1, i.e. channels-last batches -> the channels-first layout
                # Visdom expects.
                viz.images(tensor=np.transpose(x_train, (0, 3, 1, 2)),
                           nrow=n_support)
                viz.text(f'y_train: {y_train}')
                viz.images(tensor=np.transpose(x_test, (0, 3, 1, 2)),
                           nrow=n_query)
                viz.text(f'y_test: {y_test}')

        viz.save(viz.get_env_list())
コード例 #6
0
class VisdomLinePlotter(object):
    """Thin wrapper that sends line plots, image grids and text to Visdom.

    Every call targets the environment chosen at construction time.
    """
    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}  # var_name -> window id of its line chart

    def plot(self, var_name, split_name, title_name, x, y):
        """Add (x, y) to curve ``split_name`` of the ``var_name`` chart.

        The chart is created (titled ``title_name``) on the first call for
        ``var_name`` and appended to afterwards.
        """
        if var_name in self.plots:
            self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env,
                          win=self.plots[var_name], name=split_name,
                          update='append')
            return
        labels = dict(legend=[split_name], title=title_name,
                      xlabel='Epochs', ylabel=var_name)
        self.plots[var_name] = self.viz.line(X=np.array([x, x]),
                                             Y=np.array([y, y]),
                                             env=self.env, opts=labels)

    def images(self, data, win, opts):
        """Draw ``data`` as an image grid in window ``win``."""
        self.viz.images(data, env=self.env, win=win, opts=opts)

    def text(self, pred, win, opts):
        """Write ``pred`` into the text window ``win``."""
        self.viz.text(pred, env=self.env, win=win, opts=opts)
コード例 #7
0
class VisdomPlotter(object):
    """Plots whole curves and image grids to a Visdom server on port 7777."""
    def __init__(self, env_name='gan'):
        self.viz = Visdom(port=7777)
        self.env = env_name
        self.plots = {}  # var_name -> Visdom window id

    def plot(self, var_name, split_name, x, y, xlabel='epoch'):
        """Redraw the complete curve ``y`` in the window for ``var_name``.

        Args:
            var_name: window key; also used as title and y-axis label.
            split_name: legend entry of the curve.
            x: ignored, kept for call compatibility; the x-coordinates are
                always ``0 .. len(y) - 1``.
            y: the full 1-D history of values to draw.
            xlabel: x-axis label.
        """
        y = np.asarray(y)
        xs = np.arange(y.shape[0])
        # Send opts on every call: without ``update`` Visdom replaces the
        # whole window, so omitting them would drop title/labels/legend.
        opts = dict(legend=[split_name],
                    title=var_name,
                    xlabel=xlabel,
                    ylabel=var_name)
        # X and Y must be 1-D: wrapping them as ``[y]`` gives a (1, N) array,
        # which Visdom interprets as N separate one-point lines.
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(X=xs, Y=y, env=self.env,
                                                 opts=opts)
        else:
            self.viz.line(X=xs,
                          Y=y,
                          env=self.env,
                          win=self.plots[var_name],
                          name=split_name,
                          opts=opts)

    def draw(self, var_name, images):
        """Show ``images`` in the window for ``var_name`` (created on first use)."""
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.images(images, env=self.env)
        else:
            self.viz.images(images, env=self.env, win=self.plots[var_name])
コード例 #8
0
class VisdomPlotter(object):
    """Plots to Visdom: training curves, a parameter list and images."""

    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}      # var_name -> line-chart window id
        self.paramList = {}  # replaced by the parameter text window id in argsTile
        # Window id used by showImage. None makes Visdom open a new window on
        # the first call; the returned id is kept so later calls reuse it.
        self.images = None

    def argsTile(self, argsDict):
        """List every ``key = value`` pair of ``argsDict`` in one text window."""
        self.paramList = self.viz.text('<b>Training Parameters:</b>\n', env=self.env,
                                       opts=dict(width=220, height=320))
        for key, value in argsDict.items():
            # append=True adds to the existing text window instead of replacing it.
            self.viz.text(str(key) + ' = ' + str(value) + '\n', env=self.env,
                          win=self.paramList, append=True)

    def plot(self, var_name, split_name, x, y):
        """Append (x, y) to series ``split_name`` of chart ``var_name`` (created on first use)."""
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(
                X=np.array([x, x]), Y=np.array([y, y]), env=self.env,
                opts=dict(legend=[split_name], title=var_name,
                          xlabel='Epochs', ylabel=var_name))
        else:
            self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env,
                          win=self.plots[var_name], name=split_name,
                          update='append')

    def showImage(self, imageTensor):
        """Show ``imageTensor`` as a grid of 2 images per row in one reused window."""
        # ``nrow`` is an argument of ``Visdom.images``, not an opts key; inside
        # ``opts`` it is ignored and Visdom's default of 8 per row applies.
        self.images = self.viz.images(
            imageTensor, nrow=2, win=self.images, env=self.env,
            opts=dict(title='Original and Reconstructed', caption='How random.'))

    def plotPerformance(self, var_name, split_name, x, y):
        """Like ``plot`` but passes ``x`` and ``y`` to Visdom unchanged."""
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(
                X=x, Y=y, env=self.env,
                opts=dict(legend=[split_name], title=var_name,
                          xlabel='Epochs', ylabel=var_name))
        else:
            self.viz.line(X=x, Y=y, env=self.env, win=self.plots[var_name],
                          name=split_name, update='append')
コード例 #9
0
class VisdomPlotter(object):
    """Plots curves, image grids and text panes to Visdom.

    All window ids share ``self.plots``, keyed by ``var_name``.
    """
    def __init__(self, env_name='relational_net'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}

    def plot(self, var_name, split_name, x, y, xlabel='iteration'):
        """Add (x, y) to series ``split_name`` of chart ``var_name`` (created on first use)."""
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(X=np.array([x, x]),
                                                 Y=np.array([y, y]),
                                                 env=self.env,
                                                 opts=dict(legend=[split_name],
                                                           title=var_name,
                                                           xlabel=xlabel,
                                                           ylabel=var_name))
        else:
            # Visdom removed ``updateTrace``; ``line(update='append')`` is
            # its supported replacement.
            self.viz.line(X=np.array([x]),
                          Y=np.array([y]),
                          env=self.env,
                          win=self.plots[var_name],
                          name=split_name,
                          update='append')

    def draw(self, var_name, images):
        """Show ``(images + 0.5) * 255`` in the window for ``var_name``.

        NOTE(review): presumably the inputs are centred in [-0.5, 0.5], so
        the rescale maps them to [0, 255] -- confirm against the caller.
        """
        images = (images + 0.5) * 255
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.images(images, env=self.env)
        else:
            self.viz.images(images, env=self.env, win=self.plots[var_name])

    def print(self, var_name, text):
        """Write ``text`` into the text window for ``var_name`` (created on first use)."""
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.text(text, env=self.env)
        else:
            self.viz.text(text, env=self.env, win=self.plots[var_name])
class VisdomPlotter(object):
    """Plots line charts and image grids to a Visdom server.

    Window ids are cached in ``self.plots`` under ``var_name`` so repeated
    calls update the same window.
    """
    def __init__(self, env_name='gan'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}

    def plot(self, var_name, split_name, x, y, xlabel='epoch'):
        """Add (x, y) to series ``split_name`` of chart ``var_name`` (created on first use)."""
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(X=np.array([x, x]),
                                                 Y=np.array([y, y]),
                                                 env=self.env,
                                                 opts=dict(legend=[split_name],
                                                           title=var_name,
                                                           xlabel=xlabel,
                                                           ylabel=var_name))
        else:
            # Visdom removed ``updateTrace``; ``line(update='append')`` is
            # its supported replacement.
            self.viz.line(X=np.array([x]),
                          Y=np.array([y]),
                          env=self.env,
                          win=self.plots[var_name],
                          name=split_name,
                          update='append')

    def draw(self, var_name, images):
        """Show ``images`` in the window for ``var_name`` (created on first use)."""
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.images(images, env=self.env)
        else:
            self.viz.images(images, env=self.env, win=self.plots[var_name])
コード例 #11
0
class SampleImage(object):
    """Draw a class-conditional GAN sample grid from fixed inputs into Visdom.

    Latent codes and class labels are created once, so successive grids show
    how the same inputs are rendered as training progresses.
    """
    def __init__(self, n_row, nb_classes, code_size):
        self.n_row = n_row
        self.viz = Visdom()
        # (n_row * nb_classes, code_size) standard-normal latent codes.
        self.fixed_z = np.random.normal(0, 1, (n_row * nb_classes, code_size))
        # Labels 0 .. nb_classes-1, the whole sequence repeated n_row times.
        self.labels = np.array(
            [cls for _row in range(self.n_row) for cls in range(nb_classes)])
        self.win = None  # Visdom window id, set by the first draw

    def __call__(self, trainer, gan_model):
        step = trainer.global_step
        label_tensor = torch.from_numpy(self.labels)
        generated = gan_model.generate('fixed', self.fixed_z, label_tensor)
        # (v + 1) / 2: presumably maps a [-1, 1] generator output to [0, 1].
        grid = (generated + 1.) / 2.

        caption = dict(caption='Step: {}'.format(step))
        if self.win is not None:
            self.viz.images(grid, nrow=self.n_row, win=self.win, opts=caption)
        else:
            self.win = self.viz.images(grid, nrow=self.n_row, opts=caption)
コード例 #12
0
class Visualizer:
    """Visdom dashboard with fixed windows for real/fake image panes and a loss chart.

    Windows have fixed names ("real_img", "fake_batch", "loss", ...), so the
    update methods address them by name.
    """
    def __init__(self):
        self.vis = Visdom()

        # Window ids returned by initiate_windows(); None until it is called.
        self.real_image = None
        self.real_heatmap = None
        self.real_concat = None
        self.real_batch = None
        self.fake_image = None
        self.fake_concat = None
        self.fake_heatmap = None
        self.fake_batch = None

        self.loss = None

        self.point = 0  # x-coordinate of the next loss point

    @staticmethod
    def _loss_opts():
        """Return a fresh opts dict for the "loss" window.

        Used both when creating the window and when appending to it, so the
        two calls always agree on title, axis labels and legend.
        """
        return {
            "xlabel": "Iteration",
            "ylabel": "Loss",
            "title": "Training progression",
            "legend": ["pred real", "pred fake"],
        }

    def initiate_windows(self):
        """Create every window with random placeholder content and reset the loss x-axis."""
        random_image = torch.rand(1, 256, 256)
        random_batch = torch.rand(8, 1, 256, 256)
        self.real_image = self.vis.image(random_image,
                                         win="real_img",
                                         opts={"caption": "Real image"})
        self.real_heatmap = self.vis.image(random_image, win="real_map")
        self.real_concat = self.vis.image(random_image, win="real_cat")
        self.real_batch = self.vis.images(random_batch, win="real_batch")
        self.fake_image = self.vis.image(random_image, win="fake_img")
        self.fake_concat = self.vis.image(random_image, win="fake_cat")
        self.fake_heatmap = self.vis.image(random_image, win="fake_map")
        self.fake_batch = self.vis.images(random_batch, win="fake_batch")

        self.loss = self.vis.line(X=torch.zeros(1),
                                  Y=torch.zeros(1, 2),
                                  win="loss",
                                  opts=self._loss_opts())
        self.point = 0

    def update_image(self, img, name):
        """Replace the content of image window ``name``."""
        self.vis.image(img, win=name, opts={"caption": name})

    def update_batch(self, batch, name):
        """Replace the image grid in window ``name`` (8 images per row)."""
        self.vis.images(batch, win=name, opts={"caption": name}, nrow=8)

    def update_loss(self, pred_real, pred_fake):
        """Append one (pred_real, pred_fake) point pair at x = ``self.point``.

        Assumes both inputs are 1-element 1-D tensors, so stacking along dim 1
        gives the (1, 2) shape matching X -- TODO confirm with callers.
        """
        self.vis.line(X=torch.ones(1, 2) * self.point,
                      Y=torch.stack([pred_real, pred_fake], dim=1),
                      win="loss",
                      update="append",
                      opts=self._loss_opts())
        self.point += 1
コード例 #13
0
class VisdomLogger:
    """Training callback that mirrors metric values into Visdom.

    Values are routed by type:
    - scalars (Python numbers or single-element tensors) are appended to a
      line chart;
    - strings go to a text pane;
    - 2-D tensors are drawn as a heatmap;
    - 3-D tensors as an image;
    - 4-D tensors as an image grid.

    Window names are ``prefix + metric name``.

    Args:
        visdom_env: Visdom environment to log into; ``None`` disables logging.
        log_every: ``on_batch_end`` logs every this many iterations; -1 (or 0)
            disables per-batch logging.
        prefix: string prepended to every metric name.
    """
    def __init__(self, visdom_env='main', log_every=10, prefix=''):
        self.vis = None
        self.log_every = log_every
        self.prefix = prefix
        if visdom_env is not None:
            self.vis = Visdom(env=visdom_env)
            # close() with no window id closes the environment's existing
            # windows (presumably leftovers from an earlier run).
            self.vis.close()

    def on_batch_end(self, state):
        """Log ``state['metrics']`` when ``state['iters']`` is a multiple of ``log_every``."""
        iters = state['iters']
        # 0 is treated like -1: ``iters % 0`` would raise ZeroDivisionError.
        if self.log_every not in (-1, 0) and iters % self.log_every == 0:
            self.log(iters, state['metrics'])

    def on_epoch_end(self, state):
        """Always log ``state['metrics']`` at the end of an epoch."""
        self.log(state['iters'], state['metrics'])

    def log(self, iters, xs, store_history=()):
        """Send every metric in ``xs`` to Visdom.

        Args:
            iters: x-coordinate for scalar metrics.
            xs: mapping of metric name -> value (number, str or torch.Tensor).
            store_history: metric names whose image windows get Visdom's
                ``store_history`` option. Only read; an immutable default
                avoids the shared-mutable-default pitfall.

        Raises:
            TypeError: a value has an unsupported type.
            ValueError: a multi-element tensor is not 2-, 3- or 4-D.
        """
        if self.vis is None:
            return

        for name, x in xs.items():
            name = self.prefix + name
            if isinstance(x, torch.Tensor) and x.numel() == 1:
                # A single-element tensor is plotted like a plain scalar.
                x = x.item()
            if isinstance(x, (float, int)):
                self.vis.line(X=[iters],
                              Y=[x],
                              update='append',
                              win=name,
                              opts=dict(title=name),
                              name=name)
            elif isinstance(x, str):
                self.vis.text(x, win=name, opts=dict(title=name))
            elif isinstance(x, torch.Tensor):
                keep = name in store_history
                if x.dim() == 2:
                    self.vis.heatmap(x, win=name, opts=dict(title=name))
                elif x.dim() == 3:
                    self.vis.image(x,
                                   win=name,
                                   opts=dict(title=name, store_history=keep))
                elif x.dim() == 4:
                    self.vis.images(x,
                                    win=name,
                                    opts=dict(title=name, store_history=keep))
                else:
                    # Explicit raise: ``assert False`` is stripped by ``python -O``.
                    raise ValueError("incorrect tensor dim: {}".format(x.dim()))
            else:
                raise TypeError("incorrect type " + x.__class__.__name__)
コード例 #14
0
ファイル: visualize.py プロジェクト: saeedizadi/SR_CLE
class Dashboard:
    """Minimal Visdom front end that shows image batches as a titled grid."""

    def __init__(self, port=8097):
        """Connect to the Visdom server listening on ``port``."""
        self.vis = Visdom(port=port)

    def grid_plot(self, images, nrow):
        """Draw ``images`` with ``nrow`` images per row and 10 px padding."""
        grid_opts = {'title': 'Results'}
        self.vis.images(images, nrow=nrow, padding=10, opts=grid_opts)
コード例 #15
0
class Visualizer:
    """Thin wrapper around ``visdom.Visdom`` with step counters and a text log.

    Attributes not defined here are forwarded to the wrapped client, so e.g.
    ``Visualizer().text(...)`` calls ``Visdom.text``.
    """

    def __init__(self, env='default', **kwargs):
        self.vis = Visdom(env=env, **kwargs)
        self.index = {}  # window name -> next x-coordinate used by plot()
        # Accumulated HTML log text. It must not be stored as ``self.log``:
        # an instance attribute of that name shadows the log() method, and
        # every call then fails with "'str' object is not callable".
        self.log_text = ''

    def reinit(self, env='default', **kwargs):
        """Reconfigure Visdom by replacing the client; counters and log are kept."""
        self.vis = Visdom(env=env, **kwargs)
        return self

    def plot_many(self, d):
        """Call ``plot(name, value)`` for every item of ``d``."""
        for k, v in d.items():
            self.plot(k, v)

    def img_many(self, d):
        """Call ``img(name, image)`` for every item of ``d``."""
        for k, v in d.items():
            self.img(k, v)

    def plot(self, name, y):
        """Add ``y`` to line chart ``name`` at the next integer x.

        The first call for a name (x == 0) creates/replaces the window; later
        calls append to it.
        """
        x = self.index.get(name, 0)
        self.vis.line(X=np.array([x]),
                      Y=np.array([y]),
                      win=name,
                      opts=dict(title=name),
                      update=None if x == 0 else 'append')
        self.index[name] = x + 1

    def img(self, name, img_):
        """Show ``img_`` in window ``name``; a 2-D tensor first gets a leading axis."""
        if len(img_.size()) < 3:
            img_ = img_.cpu().unsqueeze(0)
        self.vis.images(tensor=img_.cpu(), win=name, opts=dict(title=name))

    def img_grid_many(self, d):
        """Call ``img_grid(name, tensor)`` for every item of ``d``."""
        for k, v in d.items():
            self.img_grid(k, v)

    def img_grid(self, name, input_3d):
        """Tile a stack of single-channel maps into one grid image.

        Intended example: an input of (36, 64, 64) becomes a 6*6 grid whose
        cells are 64*64. Values are clamped to [0, 1] first.

        NOTE(review): the code indexes ``[0]`` before ``unsqueeze(1)``, so it
        actually grids the maps of the *first* element of ``input_3d``
        (i.e. it expects e.g. (B, 36, 64, 64)). Also, ``make_grid`` is called
        with its default ``nrow`` (8 per row), not a square layout. Confirm
        which is intended.
        """
        self.img(
            name=name,
            img_=tv.utils.make_grid(
                tensor=input_3d.cpu()[0].unsqueeze(1).clamp(max=1, min=0)))

    def log(self, info, win='log'):
        """Append ``info`` with a timestamp to the log and redraw text window ``win``."""
        self.log_text += '[{time}] {info} <br>'.format(
            time=time.strftime('%m%d_%H%M%S'), info=info)
        self.vis.text(text=self.log_text, win=win)

    def __getattr__(self, name):
        """Forward attributes not found on the wrapper to the Visdom client."""
        # Only reached when normal lookup fails. If ``vis`` itself is missing
        # (e.g. an instance built via __new__, copy or unpickling), evaluating
        # ``self.vis`` below would recurse forever; fail cleanly instead.
        if name == 'vis':
            raise AttributeError(name)
        return getattr(self.vis, name)
コード例 #16
0
    def test_03(self):
        """Smoke-test Visdom: draw one random image, then a grid of random images."""
        from visdom import Visdom
        viz = Visdom()

        # Single image: uniform noise of shape (3, 512, 256), C x H x W.
        single = np.random.rand(3, 512, 256)
        viz.image(single, opts=dict(title='Random!', caption='How random.'))

        # Batch: twenty (3, 64, 64) Gaussian-noise images drawn as one grid.
        batch = np.random.randn(20, 3, 64, 64)
        viz.images(batch, opts=dict(title='Random images', caption='How random'))
コード例 #17
0
ファイル: utils.py プロジェクト: beaupranisaa/RtmlLabs
class VisdomImage(object):
    """Show image batches in Visdom, always in the environment given at construction."""

    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name

    def display_image(self, image, win, title_name):
        """Draw ``image`` into window ``win`` with ``title_name`` as its caption."""
        options = {'caption': title_name}
        self.viz.images(image, env=self.env, win=win, opts=options)
コード例 #18
0
class Logger:
    """Send images, image grids and loss curves to Visdom.

    Window ids are cached per name (``image_windows`` for single images and
    grids, ``loss_windows`` for curves), so repeated calls update the same
    window instead of opening new ones.
    """
    def __init__(self, env):
        self.viz = Visdom(env=env)
        self.losses = {}
        self.loss_windows = {}
        self.image_windows = {}

    def log(self,
            epoch=None,
            losses=None,
            images=None,
            image_grid=None,
            env=None):
        """Log one step.

        Args:
            epoch: x-coordinate of the loss points.
            losses: mapping loss name -> scalar value.
            images: mapping name -> tensor, shown as a single image after
                ``masktensor2image(tensor.data)``.
            image_grid: mapping name -> tensor, shown as an image grid.
            env: passed only to the image-grid calls; single images and
                losses go to the environment the client was created with.
        """
        for name, tensor in (images.items() if images else ()):
            picture = masktensor2image(tensor.data)
            title = {'title': name}
            if name in self.image_windows:
                self.viz.image(picture, win=self.image_windows[name], opts=title)
            else:
                self.image_windows[name] = self.viz.image(picture, opts=title)

        for name, tensor in (image_grid.items() if image_grid else ()):
            title = {'title': name}
            if name in self.image_windows:
                self.viz.images(tensor, win=self.image_windows[name], env=env,
                                opts=title)
            else:
                self.image_windows[name] = self.viz.images(tensor, env=env,
                                                           opts=title)

        for name, value in (losses.items() if losses else ()):
            xs = np.array([epoch])
            ys = np.array([value])
            if name in self.loss_windows:
                self.viz.line(X=xs, Y=ys, win=self.loss_windows[name],
                              update='append')
            else:
                self.loss_windows[name] = self.viz.line(
                    X=xs,
                    Y=ys,
                    opts={'xlabel': 'epochs', 'ylabel': name, 'title': name})
コード例 #19
0
def main():
    """Resume MNIST training from ``minst.pth`` and validate after every epoch.

    Model and optimizer state (and the last finished epoch) are restored from
    the checkpoint, then training runs for the remaining epochs below 3. Each
    step prints the loss and the running training accuracy, which is
    accumulated since the resume. Each epoch prints the validation accuracy
    and shows the last validation batch and its predictions in Visdom.
    """
    mod = Mnist()
    optimizer = optim.SGD(mod.parameters(), lr=learning_rate)
    # Restore the saved model/optimizer state to resume training (or to test).
    checkpoint = torch.load('minst.pth')
    mod.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch'] + 1
    loss_fun = nn.CrossEntropyLoss()
    vis = Visdom()
    vis.line([0.], [0.], win='train_loss', opts=dict(title='train_loss'))
    vis.line([0.], [0.], win='accuracy', opts=dict(title='acc'))
    # vis.line([0.],[0.], win='val_loss', opts=dict(title='val_loss'))
    correct = 0
    total_num = 0
    global_step = 0
    for epoch in range(start_epoch, 3):
        # The validation pass below switches the model to eval mode; switch
        # back so dropout/batch-norm (if Mnist has them) train normally.
        mod.train()
        for x, y in train_loader:
            # x=x.view(-1,28*28)
            logits = mod(x)
            train_loss = loss_fun(logits, y)
            pred = logits.argmax(dim=1)
            correct += torch.eq(y, pred).float().sum()
            total_num += x.size(0)
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
            global_step += 1
            acc = 100. * correct / total_num
            #vis.line([train_loss.item()],[global_step],win='train_loss',update='append')
            #vis.line([acc],[global_step],win='accuracy',update='append')
            print('the loss of {:d} step is {:.3f},the accuracy is {:.3f}%'.format(
                global_step, train_loss.item(), acc))
        # eval() is required when predicting (disables dropout, uses running
        # batch-norm statistics); see
        # https://www.zhihu.com/question/363144860/answer/951669576
        mod.eval()
        with torch.no_grad():
            val_correct = 0
            val_total = 0
            for validation_images, validation_label in validation_loader:
                # validation_images=validation_images.view(-1,28*28)
                val_logits = mod(validation_images)
                pred = val_logits.argmax(dim=1)
                val_loss = loss_fun(val_logits, validation_label)
                val_correct += torch.eq(pred, validation_label).float().sum()
                val_total += validation_images.size(0)
            # vis.line([val_loss.item()],[global_step],win='val_loss',update='append')
            # Show the last validation batch and its predicted labels.
            vis.images(validation_images.view(-1, 1, 28, 28), win='x')
            vis.text(str(pred.detach().cpu().numpy()), win='pred',
                     opts=dict(title='pred'))
            val_acc = 100. * val_correct / val_total
            print('the val acc of {:d} epoch is {:.3f}%'.format(epoch, val_acc))
コード例 #20
0
class VisdomImagePlotter(object):
    """Draw image batches into Visdom; no window id is passed, so each call opens a new window."""

    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}  # not used by this class's own methods

    def plot(self, image, title, nrow=8):
        """Show ``image`` as a grid with ``nrow`` images per row, titled ``title``."""
        grid_opts = {'title': title}
        self.viz.images(image, nrow=nrow, env=self.env, opts=grid_opts)
コード例 #21
0
ファイル: utils.py プロジェクト: qwerty1917/radar-ewc
class VisdomPlotter(object):
    """Line charts and captioned image grids on a Visdom server."""

    def __init__(self, env_name: str, port: int):
        self.viz = Visdom(port=port)
        self.env = env_name
        self.line_plots = {}    # var_name -> line-chart window id
        self.image_frames = {}  # caption -> image-grid window id

    def plot(self, var_name, split_name, title_name, x, y):
        """Add (x, y) to series ``split_name`` of chart ``var_name``, creating it on first use."""
        point_x = np.array([x])
        point_y = np.array([y])
        if var_name in self.line_plots:
            self.viz.line(X=point_x,
                          Y=point_y,
                          env=self.env,
                          win=self.line_plots[var_name],
                          name=split_name,
                          update='append')
            return
        chart_opts = dict(legend=[split_name],
                          title=title_name,
                          xlabel='Iteration',
                          ylabel=var_name)
        self.line_plots[var_name] = self.viz.line(X=point_x,
                                                  Y=point_y,
                                                  env=self.env,
                                                  opts=chart_opts)

    def draw(self, caption, images):
        """Show ``images`` (10 per row, 2 px padding) in the window keyed by ``caption``."""
        grid_kwargs = dict(nrow=10,
                           padding=2,
                           env=self.env,
                           opts={"caption": caption, "title": caption})
        if caption in self.image_frames:
            self.viz.images(images, win=self.image_frames[caption], **grid_kwargs)
        else:
            self.image_frames[caption] = self.viz.images(images, **grid_kwargs)
コード例 #22
0
def main():
    """Train a VAE on MNIST and show test-batch reconstructions in Visdom.

    Runs 1000 epochs. After each epoch, one test batch and the model's
    reconstruction of it are drawn into the Visdom windows 'x' and 'x_hat'.
    """
    mnist_train = datasets.MNIST('E:/ai_learning_resource/mnist',
                                 True,
                                 transform=transforms.Compose(
                                     [transforms.ToTensor()]),
                                 download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('E:/ai_learning_resource/mnist',
                                False,
                                transform=transforms.Compose(
                                    [transforms.ToTensor()]),
                                download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    x, _ = next(iter(mnist_train))
    print('x:', x.shape)

    # Fall back to the CPU instead of failing when CUDA is unavailable.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = Visdom()

    for epoch in range(1000):
        model.train()
        for x, _ in mnist_train:
            x = x.to(device)

            # The model returns (x_hat, kld); kld may be None.
            x_hat, kld = model(x)
            loss = criteon(x_hat, x)

            if kld is not None:
                # loss = -elbo = reconstruction MSE + 1.0 * kld
                elbo = -loss - 1.0 * kld
                loss = -elbo

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # kld can be None (see above); only call .item() on a tensor.
            print(epoch, 'loss', loss.item(),
                  'kld:', None if kld is None else kld.item())

        model.eval()
        x, _ = next(iter(mnist_test))
        x = x.to(device)
        with torch.no_grad():
            # forward() returns an (x_hat, kld) tuple; only x_hat is drawable.
            x_hat, _ = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))
コード例 #23
0
class VisdomLinePlotter(object):
    """Plots line charts and image grids to a Visdom environment.

    With ``disable=True`` no Visdom client is created and every plotting
    method returns immediately, so plotting calls can stay in place when no
    server is available.
    """
    def __init__(self, env_name='main', port=8050, disable=False):
        self.disable = disable
        # Initialise these even when disabled so the object is always complete.
        self.env = env_name
        self.plots = {}
        if self.disable:
            return

        try:
            self.viz = Visdom(port=port)
        except ConnectionError as e:  # also covers ConnectionRefusedError
            raise ConnectionError(
                "Visdom Server not running, please launch it with `visdom` in the terminal"
            ) from e

    def clear(self):
        """Forget all known windows; the next plot/imshow call creates new ones."""
        self.plots = {}

    def imshow(self, var_name, images):
        """Draw ``images`` into the window tracked under ``var_name``, creating it on first use."""
        if self.disable:
            return

        if var_name not in self.plots:
            # Create the window in self.env, the same env later updates target.
            self.plots[var_name] = self.viz.images(images, env=self.env)
        else:
            self.viz.images(images, win=self.plots[var_name], env=self.env)

    def plot(self, window_id, variable, title, x, y, xlabel='epochs'):
        """Append the point ``(x, y)`` to trace ``variable`` in window ``window_id``.

        The first call for a window creates it with a two-point (degenerate)
        line at ``(x, y)``; later calls append to the named trace.
        """
        if self.disable:
            return

        if window_id not in self.plots:
            self.plots[window_id] = self.viz.line(X=np.array([x, x]),
                                                  Y=np.array([y, y]),
                                                  env=self.env,
                                                  opts=dict(legend=[variable],
                                                            title=title,
                                                            xlabel=xlabel,
                                                            ylabel=variable))
        else:
            self.viz.line(X=np.array([x]),
                          Y=np.array([y]),
                          env=self.env,
                          win=self.plots[window_id],
                          name=variable,
                          update='append')
コード例 #24
0
class Drawer:
    """Accumulates scalar series and image batches and mirrors them to Visdom.

    Every key gets its own Visdom window (``win=key``) inside environment
    ``name``.
    """
    def __init__(self, name='evn'):
        self.vis = Visdom(env=name)
        self.name = name
        self.data = defaultdict(list)

    def add_value(self, key, value, update='append', name='', alpha=0.02):
        """Append ``value`` to series ``key`` and redraw the whole series.

        Once a series holds more than 200 points it is drawn smoothed by
        ``Drawer.smooth`` with window fraction ``alpha``.

        ``update`` and ``name`` are accepted for call compatibility but are
        not used: the full series is redrawn into window ``key`` each call.
        """
        if hasattr(value, 'item'):
            # Convert objects exposing .item() (e.g. tensor/numpy scalars) to Python scalars.
            value = value.item()
        series = self.data[key]
        series.append(value)
        ys = np.array(series)
        if len(series) > 200:
            ys = Drawer.smooth(ys, alpha=alpha)
        self.vis.line(X=np.arange(len(ys)),
                      Y=ys,
                      opts={
                          'title': key,
                          'xlabel': 'step'
                      },
                      win=key,
                      env=self.name)

    @staticmethod
    def smooth(x, alpha=0.02):
        """Return a moving average of ``x`` with the same length as ``x``.

        The window is about ``alpha * len(x)`` samples, forced to an odd
        length. Both ends are padded with the mean of the first/last window
        so the 'valid' convolution keeps the original length. If the window
        would be empty (``alpha * len(x) < 1``), ``x`` is returned unsmoothed
        as a float array.
        """
        n = int(alpha * len(x))
        if n < 1:
            # Without this guard the pad width and window become negative.
            return np.asarray(x, dtype=float)
        k = n // 2 if n % 2 != 0 else (n // 2) - 1  # pad width per side == (ks - 1) // 2
        ks = n - int(n % 2 == 0)  # odd window length
        cs = x[:ks].mean()
        cs2 = x[-ks:].mean()
        return np.convolve(np.pad(x,
                                  k,
                                  mode='constant',
                                  constant_values=(cs, cs2)),
                           np.ones(ks) / ks,
                           mode='valid')

    def add_images(self, key, images, nrow=None):
        """Store a batch of images under ``key`` and draw it as a grid.

        ``images`` must be 4-dimensional. When ``nrow`` is None, a near-square
        grid is used: ``ceil(sqrt(len(images)))``.
        """
        if nrow is None:
            nrow = int(math.ceil(np.sqrt(len(images))))
        assert images.ndim == 4
        self.data[key] = images
        self.vis.images(images,
                        opts={'title': key},
                        nrow=nrow,
                        win=key,
                        env=self.name)
コード例 #25
0
class VisdomLinePlotter(object):
    """Plots to Visdom and mirrors every plotted point to a CSV file."""
    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}

    def plot(self, var_name, split_name, x, y, exp_name='test', env=None):
        """Append ``(x, y)`` to trace ``split_name`` of window ``var_name``.

        The point is also appended as ``"x, y"`` to
        ``runs/<exp_name>/data/<split_name>_<var_name>_data.csv``.
        ``env`` overrides the plotter's default environment for this call.
        """
        print_env = env if env is not None else self.env
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(X=np.array([x, x]), Y=np.array([y, y]), env=print_env, opts=dict(
                legend=[split_name],
                title=var_name,
                xlabel='Epochs',
                ylabel=var_name
            ))
        else:
            # `updateTrace` was removed from Visdom; `line(update='append')` replaces it.
            self.viz.line(X=np.array([x]), Y=np.array([y]), env=print_env,
                          win=self.plots[var_name], name=split_name, update='append')

        data_dir = 'runs/%s/data/' % exp_name
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(data_dir, exist_ok=True)
        csv_path = data_dir + '%s_%s_data.csv' % (split_name, var_name)
        with open(csv_path, 'a') as f:
            f.write('%d, %f\n' % (x, y))

    def plot_mask(self, masks, epoch):
        """Draw ``masks`` as a stacked bar chart titled with ``epoch``."""
        self.viz.bar(
            X=masks,
            env=self.env,
            opts=dict(
                stacked=True,
                title=epoch,
            )
        )

    def plot_image(self, image, epoch, exp_name='test'):
        """Draw a single image in the ``<exp_name>_img`` environment, captioned with ``epoch``."""
        self.viz.image(image, env=exp_name+'_img', opts=dict(
            caption=epoch,
            ))

    def plot_images(self, images, run_split, epoch, nrow, padding=2, exp_name='test'):
        """Draw an image grid in the ``<exp_name>_img`` environment, captioned ``<run_split>_<epoch>``."""
        self.viz.images(images, env=exp_name+'_img', nrow=nrow, padding=padding, opts=dict(
            caption='%s_%d' % (run_split, epoch),
            jpgquality=100,
            ))
コード例 #26
0
ファイル: utilities.py プロジェクト: pablo1n7/SketchZoomsDeep
class VisdomLinePlotter(object):
    """Plots line charts and a single, replaceable image batch to Visdom."""

    def __init__(self, env_name='main', port=8097):
        self.viz = Visdom(port=port)
        self.env = env_name
        self.plots = {}
        self.scores_window = None
        self.image_window = None

    def plot(self, var_name, split_name, x, y, x_label='Epochs'):
        """Append ``(x, y)`` to trace ``split_name`` in window ``var_name``, creating it if needed."""
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(X=np.array([x, x]), Y=np.array([y, y]), env=self.env, opts=dict(
                legend=[split_name],
                title=var_name,
                xlabel=x_label,
                ylabel=var_name
            ))
        else:
            self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env, win=self.plots[var_name], name=split_name, update='append')

    def close_window(self, var_name):
        """Close the window tracked under ``var_name`` and forget it (KeyError if unknown)."""
        self.viz.close(self.plots[var_name])
        del self.plots[var_name]

    def images(self, images):
        """Replace the current image-batch window with a new grid of ``images``."""
        if self.image_window is not None:
            self.viz.close(self.image_window)

        # NOTE(review): `nrow=2` inside opts differs from the nrow=3 argument;
        # confirm which one is intended.
        self.image_window = self.viz.images(images, nrow=3, env=self.env,
                                            opts=dict(nrow=2, title='Images Batch'))
コード例 #27
0
class VisdomLogger():
    """
    Thin wrapper around a Visdom client for plotting lines, images and scatters.

    If the module-level ``Visdom`` is None (visdom could not be imported), a
    warning is emitted, ``self.viz`` is None and every method is a no-op.

    Parameters
    ----------
    - server: str, URL of the Visdom server
    - port: int, port of the Visdom server
    """
    def __init__(self,
                 server='http://localhost',
                 port=8097):
        self.viz = None
        if Visdom is None:
            warnings.warn("Couldn't import visdom: `pip install visdom`")
        else:
            self.viz = Visdom(server=server, port=port)

    def deleteWindow(self, win):
        """Close Visdom window ``win``."""
        if self.viz is None:
            return
        self.viz.close(win=win)

    def appendLine(self, name, win, X, Y, xlabel='empty', ylabel='empty'):
        """Append points to trace ``name`` in window ``win`` (titled "Loss").

        Axis labels (and a legend) are set only when both ``xlabel`` and
        ``ylabel`` are given.
        """
        if self.viz is None:
            return
        if xlabel == 'empty' or ylabel == 'empty':
            self.viz.line(X=X, Y=Y, win=win, name=name, update='append', opts=dict(title="Loss"))
        else:
            self.viz.line(X=X, Y=Y, win=win, name=name, update='append', opts=dict(title="Loss", xlabel=xlabel, ylabel=ylabel, showlegend=True))

    def plotLine(self, name, win, X, Y):
        """Draw trace ``name`` into window ``win`` without an append update."""
        if self.viz is None:
            return
        self.viz.line(X=X, Y=Y, win=win, name=name)

    def plotImage(self, image, win, title="Image", caption="Just a Image"):
        """Draw a single image into window ``win``."""
        if self.viz is None:
            return
        self.viz.image(image,
                       win=win,
                       opts=dict(title=title, caption=caption))

    def plotImages(self, images, win, nrow, caption="Validation Output"):
        """Draw a grid of images with ``nrow`` images per row into window ``win``."""
        if self.viz is None:
            return
        self.viz.images(images,
                        win=win,
                        nrow=nrow,
                        opts=dict(caption=caption))

    def plot3dScatter(self, point, win):
        """Print ``point`` and draw it as a scatter plot in window ``win``."""
        if self.viz is None:
            return
        print("Point is", point)
        # NOTE(review): `update` is passed inside opts rather than as the
        # `update=` argument of scatter(); confirm this is intended.
        self.viz.scatter(X=point,
                         win=win,
                         opts=dict(update='update'))
コード例 #28
0
class VisdomImgsPlotter(object):
    """Draw labelled image grids in Visdom, keeping one window per name."""
    def __init__(self, env_name='main', port=8097):
        self.vis = Visdom(port=port)
        self.env = env_name
        # Maps a plot name to the Visdom window id returned on first draw.
        self.plots = {}

    def plot(self, var_name, images, labels):
        """Show ``images`` titled ``var_name`` and captioned ``labels``.

        The first call for a name opens a window; later calls redraw into it.
        """
        options = dict(title=var_name, caption=labels)
        if var_name in self.plots:
            self.vis.images(images,
                            env=self.env,
                            win=self.plots[var_name],
                            opts=options)
            return
        self.plots[var_name] = self.vis.images(images,
                                               env=self.env,
                                               opts=options)
コード例 #29
0
ファイル: utils.py プロジェクト: qwerty1917/radar-ewc
class VisdomImagesPlotter(object):
    """Show image grids in Visdom, reusing one window per caption."""
    def __init__(self, env_name: str, port: int):
        self.viz = Visdom(port=port)
        self.env = env_name
        # Maps a caption to the Visdom window id created for it.
        self.frames = {}

    def draw(self, caption, images):
        """Draw ``images`` (10 per row, 2px padding) into the window for ``caption``.

        Images go to ``self.env``; previously the stored env was never passed,
        so everything landed in Visdom's default environment.
        """
        if caption not in self.frames:
            self.frames[caption] = self.viz.images(images,
                                                   nrow=10,
                                                   padding=2,
                                                   env=self.env,
                                                   opts={"caption": caption})

        else:
            self.viz.images(images,
                            nrow=10,
                            padding=2,
                            env=self.env,
                            opts={"caption": caption},
                            win=self.frames[caption])
コード例 #30
0
class Reconstruction(object):
    """Callback that shows a fixed batch of 100 images next to their reconstructions.

    The original batch is drawn once; the reconstruction window is redrawn
    on every call.
    """
    def __init__(self, data_loader):
        first_batch, _ = next(iter(data_loader))
        self.images = first_batch.type(FloatTensor)
        assert self.images.shape[0] == 100
        self.viz = Visdom()
        self.win_original = None
        self.win_recon = None

    def __call__(self, trainer, gan_model):
        reconstructed = gan_model.reconstruct(self.images)
        if self.win_original is None:
            self.win_original = self.viz.images(
                self.images, nrow=10, opts=dict(caption='Original Images'))
        recon_kwargs = dict(nrow=10, opts=dict(caption='Reconstructed Images'))
        if self.win_recon is not None:
            self.viz.images(reconstructed, win=self.win_recon, **recon_kwargs)
        else:
            self.win_recon = self.viz.images(reconstructed, **recon_kwargs)
コード例 #31
0
ファイル: demo.py プロジェクト: AbhinavJain13/visdom
    if os.path.isfile(videofile):
        viz.video(videofile=videofile)
except ImportError:
    print('Skipped video example')


# Single-image demo: a random array of shape (3, 512, 256) drawn as one image.
viz.image(
    np.random.rand(3, 512, 256),
    opts={'title': 'Random!', 'caption': 'How random.'},
)

# Image-grid demo: a batch of twenty random (3, 64, 64) arrays tiled in one window.
viz.images(
    np.random.randn(20, 3, 64, 64),
    opts={'title': 'Random images', 'caption': 'How random.'},
)

# scatter plots
Y = np.random.rand(100)
old_scatter = viz.scatter(
    X=np.random.rand(100, 2),
    Y=(Y[Y > 0] + 1.5).astype(int),
    opts=dict(
        legend=['Didnt', 'Update'],
        xtickmin=-50,
        xtickmax=50,
        xtickstep=0.5,
        ytickmin=-50,
        ytickmax=50,
        ytickstep=0.5,