コード例 #1
0
def visualize(epoch, input):  # C,target,fin1,fin2,fin3):
    """Show channel 0 of sample 0 of ``input`` as a Visdom heatmap.

    A new client for the default Visdom server is created on every call and
    the connection is asserted before anything is sent. The heatmap uses the
    'Electric' colormap and is titled ``'Epoch-<epoch> input'``.

    Args:
        epoch: value interpolated into the window title.
        input: indexable as ``input[0, 0]``; that 2-D slice is what gets
            plotted (presumably a batch x channel x H x W tensor -- confirm).
    """
    from visdom import Visdom

    client = Visdom()
    assert client.check_connection()

    first_map = input[0, 0]
    heatmap_opts = dict(colormap='Electric',
                        title='Epoch-{} input'.format(epoch))
    client.heatmap(first_map, opts=heatmap_opts)
コード例 #2
0
ファイル: train_img.py プロジェクト: sid7954/838Project
class VisdomLinePlotter(object):
    """Plots to Visdom.

    Keeps one line-chart window per variable name; each split (e.g. a
    train/val label) becomes a named trace inside that window.
    """
    def __init__(self, env_name='main', server='http://172.220.4.32',
                 port=6006):
        # server/port used to be hard-coded; they are now keyword arguments
        # whose defaults keep the original host and port. Visdom documents
        # ``server`` as a URL including the scheme, hence the ``http://``.
        self.viz = Visdom(server=server, port=port)
        self.env = env_name
        self.plots = {}  # var_name -> Visdom window handle

    def plot(self, var_name, split_name, x, y, env=None):
        """Add the point ``(x, y)`` to trace ``split_name`` of chart ``var_name``.

        Args:
            var_name: chart identifier, also used as title and y-axis label.
            split_name: trace (legend entry) the point belongs to.
            x: x-coordinate (labelled 'Epochs').
            y: y-coordinate.
            env: Visdom environment; defaults to the one given at construction.
        """
        print_env = env if env is not None else self.env
        if var_name not in self.plots:
            # The window is created from a two-point segment at (x, y).
            self.plots[var_name] = self.viz.line(X=np.array([x, x]),
                                                 Y=np.array([y, y]),
                                                 env=print_env,
                                                 opts=dict(legend=[split_name],
                                                           title=var_name,
                                                           xlabel='Epochs',
                                                           ylabel=var_name))
        else:
            # ``updateTrace`` was removed from Visdom; ``line`` with
            # ``update='append'`` is its replacement.
            self.viz.line(X=np.array([x]),
                          Y=np.array([y]),
                          env=print_env,
                          win=self.plots[var_name],
                          name=split_name,
                          update='append')

    def plot_heatmap(self, map, epoch):
        """Show ``map`` as a heatmap titled ``'activations <epoch>'``.

        Axes are labelled 'modules' (x) and 'classes' (y). No window id is
        passed, so every call opens a new window.
        """
        self.viz.heatmap(X=map,
                         env=self.env,
                         opts=dict(title='activations {}'.format(epoch),
                                   xlabel='modules',
                                   ylabel='classes'))
コード例 #3
0
ファイル: visualizer.py プロジェクト: FLHonker/ZAQ-code
class VisdomPlotter(object):
    """ Visualizer

    Thin wrapper around a ``Visdom`` client bound to one port/env with
    TensorBoard-like ``add_*`` helpers. Each helper uses ``win`` both as the
    window id and as the default window title.
    """
    def __init__(self, port='13579', env='main'):
        self.cur_win = {}
        self.env = env
        self.visdom = Visdom(port=port, env=env)

    def add_scalar(self, win, x, y, opts=None, trace_name=None):
        """ Draw line

        Append point(s) ``(x, y)`` to trace ``trace_name`` of window ``win``.
        Non-list ``x``/``y`` are wrapped in one-element lists. When ``win`` is
        None no update mode is requested, so a new window is created.
        """
        if not isinstance(x, list):
            x = [x]
        if not isinstance(y, list):
            y = [y]
        default_opts = {'title': win}
        if opts is not None:
            default_opts.update(opts)
        update = 'append' if win is not None else None
        self.visdom.line(X=x,
                         Y=y,
                         opts=default_opts,
                         win=win,
                         env=self.env,
                         update=update,
                         name=trace_name)

    def add_image(self, win, img, opts=None):
        """ vis image in visdom

        ``opts`` entries override the default ``{'title': win}``.
        """
        default_opts = dict(title=win)
        if opts is not None:
            default_opts.update(opts)
        self.visdom.image(img=img, win=win, opts=default_opts, env=self.env)

    @staticmethod
    def _table_html(tbl):
        """Render mapping ``tbl`` as a two-column [Key]/[Value] HTML table.

        Keys and values are converted with ``str`` and HTML-escaped, so
        values such as ``"<class 'Net'>"`` are shown literally instead of
        being parsed as markup.
        """
        import html

        rows = ''.join(
            '<tr><td>%s</td><td>%s</td></tr>'
            % (html.escape(str(k)), html.escape(str(v)))
            for k, v in tbl.items())
        return ('<table width="100%"> '
                '<tr><th>[Key]</th><th>[Value]</th></tr>'
                + rows + '</table>')

    def add_table(self, win, tbl, opts=None):
        """Show the key/value mapping ``tbl`` as an HTML table in ``win``."""
        default_opts = {'title': win}
        if opts is not None:
            default_opts.update(opts)
        self.visdom.text(self._table_html(tbl), win=win, env=self.env,
                         opts=default_opts)

    def add_heatmap(self, win, X, opts=None):
        """Show ``X`` as a heatmap; the colour range defaults to [0, 1]."""
        default_opts = {'title': win, 'xmin': 0, 'xmax': 1}
        if opts is not None:
            default_opts.update(opts)
        self.visdom.heatmap(X=X, win=win, opts=default_opts, env=self.env)
コード例 #4
0
class VisdomLogger:
    """Training callback that mirrors a ``metrics`` dict into Visdom windows.

    Each metric gets a window named ``prefix + key``. How a value is drawn
    depends on its type: real numbers and one-element tensors become line
    points, strings become text, 2-D tensors heatmaps, 3-D tensors a single
    image and 4-D tensors an image grid.
    """

    def __init__(self, visdom_env='main', log_every=10, prefix=''):
        self.vis = None
        self.log_every = log_every
        self.prefix = prefix
        if visdom_env is not None:
            self.vis = Visdom(env=visdom_env)
            # close() without a window id: per the Visdom API this closes
            # every window in the env, so each run starts clean.
            self.vis.close()

    def on_batch_end(self, state):
        """Log ``state['metrics']`` every ``log_every`` iterations (-1 disables)."""
        iters = state['iters']
        if self.log_every != -1 and iters % self.log_every == 0:
            self.log(iters, state['metrics'])

    def on_epoch_end(self, state):
        """Always log ``state['metrics']`` at the end of an epoch."""
        self.log(state['iters'], state['metrics'])

    def _append_point(self, iters, value, name):
        """Append ``(iters, value)`` to the line chart in window ``name``."""
        self.vis.line(X=[iters],
                      Y=[value],
                      update='append',
                      win=name,
                      opts=dict(title=name),
                      name=name)

    def log(self, iters, xs, store_history=()):
        """Send every entry of ``xs`` to Visdom.

        Args:
            iters: x-coordinate for line plots.
            xs: mapping of metric name to value.
            store_history: names whose image windows keep a history slider.
                Only membership is tested, so any container works.

        Raises:
            TypeError: a value has an unsupported type.
            ValueError: a multi-element tensor is not 2-, 3- or 4-D.
        """
        import numbers

        if self.vis is None:
            return

        for name, x in xs.items():
            name = self.prefix + name
            if isinstance(x, str):
                self.vis.text(x, win=name, opts=dict(title=name))
            elif isinstance(x, numbers.Real):
                # numbers.Real also covers numpy scalars such as np.float32,
                # which the previous (float, int) check rejected.
                self._append_point(iters, float(x), name)
            elif isinstance(x, torch.Tensor):
                if x.numel() == 1:
                    self._append_point(iters, x.item(), name)
                elif x.dim() == 2:
                    self.vis.heatmap(x, win=name, opts=dict(title=name))
                elif x.dim() == 3:
                    self.vis.image(x,
                                   win=name,
                                   opts=dict(title=name,
                                             store_history=name
                                             in store_history))
                elif x.dim() == 4:
                    self.vis.images(x,
                                    win=name,
                                    opts=dict(title=name,
                                              store_history=name
                                              in store_history))
                else:
                    # Explicit raise: ``assert`` is stripped under ``-O``.
                    raise ValueError("incorrect tensor dim")
            else:
                raise TypeError("incorrect type " + x.__class__.__name__)
コード例 #5
0
ファイル: asen.py プロジェクト: realcatking/asen
class VisdomLinePlotter(object):
    """Plots to Visdom: line curves plus attention-map overlays."""

    def __init__(self, env_name='main'):
        # ``args`` is a module-level namespace defined elsewhere in the file
        # (presumably parsed command-line options -- confirm).
        self.viz = Visdom(port=args.visdom_port)
        self.env = env_name
        self.plots = {}  # window title -> Visdom window handle

    def plot(self, var_name, split_name, x, y):
        """Plot curve graph: add ``(x, y)`` to trace ``split_name`` of ``var_name``."""
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(X=np.array([x, x]),
                                                 Y=np.array([y, y]),
                                                 env=self.env,
                                                 opts=dict(legend=[split_name],
                                                           title=var_name,
                                                           xlabel='Epochs',
                                                           ylabel=var_name))
        else:
            self.viz.line(X=np.array([x]),
                          Y=np.array([y]),
                          env=self.env,
                          win=self.plots[var_name],
                          name=split_name,
                          update='append')

    def plot_attention(self, imgs, heatmaps, tasks, alpha=0.5,
                       size=(224, 224)):
        """Plot attention map: overlay each heatmap on its image.

        For each task index ``i``, two windows are shown or refreshed: the
        blended colour overlay and the raw normalised heatmap.

        Args:
            imgs: per-task images blended with the coloured map; each must be
                H x W x 3 at ``size`` to match the overlay (TODO confirm).
            heatmaps: per-task 2-D attention maps, resized to ``size``.
            tasks: keys into ``meta.data['ATTRIBUTES']`` (module-level
                ``meta``) used to name the windows.
            alpha: weight of the original image in the blend.
            size: ``(width, height)`` for ``cv2.resize``; defaults to the
                previously hard-coded 224x224.
        """
        for i, task in enumerate(tasks):
            heatmap = cv2.resize(heatmaps[i], size,
                                 interpolation=cv2.INTER_CUBIC)
            heatmap = np.maximum(heatmap, 0)
            peak = np.max(heatmap)
            # An all-zero map would otherwise become 0/0 = NaN everywhere.
            if peak > 0:
                heatmap = heatmap / peak
            heatmap_marked = np.uint8(cm.gist_rainbow(heatmap)[..., :3] * 255)
            heatmap_marked = cv2.cvtColor(heatmap_marked, cv2.COLOR_BGR2RGB)
            heatmap_marked = np.uint8(imgs[i] * alpha + heatmap_marked *
                                      (1. - alpha))
            # Visdom images are channel-first.
            heatmap_marked = heatmap_marked.transpose([2, 0, 1])

            win_name = 'img %d - %s' % (i, meta.data['ATTRIBUTES'][task])
            if win_name not in self.plots:
                self.plots[win_name] = self.viz.image(
                    heatmap_marked, env=self.env, opts=dict(title=win_name))
                self.plots[win_name + 'heatmap'] = self.viz.heatmap(
                    heatmap, env=self.env, opts=dict(title=win_name))
            else:
                self.viz.image(heatmap_marked,
                               env=self.env,
                               win=self.plots[win_name],
                               opts=dict(title=win_name))
                self.viz.heatmap(heatmap,
                                 env=self.env,
                                 win=self.plots[win_name + 'heatmap'],
                                 opts=dict(title=win_name))
コード例 #6
0
def plot_img(X=None, win=None, env=None, plot=None, port=_port):
    """Display ``X`` in Visdom: 2-D arrays as a heatmap, 3-D ones as an image.

    Args:
        X: array whose ``ndim`` is 2 or 3; any other rank is silently ignored.
            For 3-D input the last axis is channels (the original note calls
            the layout "BWC"); it is normalised with ``normalize_img`` and
            transposed to channel-first before plotting.
        win: window id, also used as the window title.
        env: Visdom environment name.
        plot: existing ``Visdom`` client; a new one on ``port`` is created
            when omitted.
        port: port for the newly created client.
    """
    viz = plot if plot is not None else Visdom(port=port)
    rank = X.ndim
    if rank == 2:
        # Rows are flipped before plotting (presumably so row 0 ends up on
        # top in Visdom's heatmap -- confirm).
        viz.heatmap(X=np.flipud(X), win=win, opts=dict(title=win), env=env)
    elif rank == 3:
        chw = normalize_img(X).transpose(2, 0, 1)
        viz.image(chw, win=win, opts=dict(title=win), env=env)
コード例 #7
0
class VisdomHeatmap():
    """Render labelled square heatmaps in a Visdom environment."""

    def __init__(self, env_name='main', port=8097):
        self.viz = Visdom(port=port)
        self.env = env_name

    def plot(self, title_name, x, class_list):
        """Show matrix ``x`` with ``class_list`` labelling both axes.

        The matrix is flipped upside-down before plotting, so the row labels
        are reversed to stay aligned with their rows. No window id is
        passed, so each call opens a new window.
        """
        flipped = np.flipud(x)
        row_labels = [*reversed(class_list)]
        heatmap_opts = dict(title=title_name,
                            columnnames=class_list,
                            rownames=row_labels)
        self.viz.heatmap(X=flipped, env=self.env, opts=heatmap_opts)
コード例 #8
0
class VisdomLinePlotter(object):
    """Plots to Visdom: metric curves, activation heatmaps and fc-bias traces.

    Every plot is sent to the environment given at construction.
    """
    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        self.plots = {}  # var_name -> Visdom window handle

    def plot(self, var_name, split_name, x, y):
        """Add ``(x, y)`` to trace ``split_name`` of chart ``var_name``."""
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(X=np.array([x, x]),
                                                 Y=np.array([y, y]),
                                                 env=self.env,
                                                 name=split_name,
                                                 opts=dict(legend=[split_name],
                                                           title=var_name,
                                                           xlabel='Epochs',
                                                           ylabel=var_name))
        else:
            # Append a single point; previously the same point was sent twice.
            self.viz.line(X=np.array([x]),
                          Y=np.array([y]),
                          win=self.plots[var_name],
                          env=self.env,
                          name=split_name,
                          update='append')

    @staticmethod
    def _collapse(t):
        """Sum a tensor over its first two dimensions."""
        return t.sum(dim=0).sum(dim=0)

    def image(self, map1, map2):
        """Show each map, summed over its first two dims, as a Viridis heatmap."""
        map1 = self._collapse(map1)
        map2 = self._collapse(map2)
        self.viz.heatmap(X=map1.data, env=self.env,
                         opts=dict(colormap='Viridis'))
        self.viz.heatmap(X=map2.data, env=self.env,
                         opts=dict(colormap='Viridis'))

    @staticmethod
    def _bias_series(x, bias):
        """Return ``(X, Y)`` arrays of shape (2, n) for ``n`` bias values at epoch ``x``.

        Column ``j`` is a degenerate two-point segment ``(x, b_j)-(x, b_j)``,
        i.e. one trace per bias entry. ``bias`` is detached and moved to the
        CPU, so tensors that require grad or live on a GPU are accepted.
        """
        values = bias.detach().cpu().numpy().reshape(1, -1)
        ys = np.repeat(values, 2, axis=0)
        xs = np.full(ys.shape, x)
        return xs, ys

    def weight(self, x, state):
        """Plot each entry of ``state['fc.bias']`` as its own trace in window 'weights'.

        ``x == 1`` (re)creates the window; any other ``x`` appends to it.
        The layer name is hard-coded; other models may need a different key.
        """
        xs, ys = self._bias_series(x, state['fc.bias'])
        if x == 1:
            self.viz.line(X=xs, Y=ys, win='weights', env=self.env)
        else:
            self.viz.line(X=xs, Y=ys, win='weights', env=self.env,
                          update='append')
コード例 #9
0
ファイル: graphic.py プロジェクト: HelenGuohx/cv-ferattn-code
class HeatMapVisdom(object):
    """Heat Map to Visdom: one window per title, refreshed on each call."""

    def __init__(self, env_name='main', heatsize=None):
        self.vis = Visdom(use_incoming_socket=False)
        self.env = env_name
        self.hmaps = {}  # title -> Visdom window handle
        self.heatsize = heatsize  # optional dsize for cv2.resize

    def show(self, title, image):
        """Show ``image`` as a heatmap in the window associated with ``title``.

        When ``heatsize`` is set, the image is first resized to it with
        bilinear interpolation.
        """
        if self.heatsize:
            image = cv2.resize(image, self.heatsize,
                               interpolation=cv2.INTER_LINEAR)

        if title in self.hmaps:
            self.vis.heatmap(image, env=self.env, win=self.hmaps[title],
                             opts=dict(title=title))
            return

        self.hmaps[title] = self.vis.heatmap(image, env=self.env,
                                             opts=dict(title=title))
コード例 #10
0
def _tile_frames(frames, count, rows, cols, gap):
    """Place ``frames[0, t, 0]`` for ``t < count`` side by side, ``gap`` zero columns apart."""
    canvas = np.zeros((rows, int(count * (cols + gap) - gap)))
    for idx in range(count):
        start = idx * (cols + gap)
        canvas[0:rows, start:start + cols] = frames[0, idx, 0]
    return canvas


def display_timeseries(strumodel,
                       BatchData,
                       BatchLabel,
                       plot=None,
                       name='default',
                       port=_port,
                       max_frames=10,
                       gap=9):
    """Plot input, ground-truth and predicted frame sequences for one random sample.

    One sample is drawn at random from the batch. Its first channel is tiled
    horizontally for up to ``max_frames`` time steps. Three heatmaps are
    shown in env ``name``: ``<name>_OriginalImage``, ``<name>_GroundTruth``
    and ``<name>_Prediction``.

    Args:
        strumodel: object with ``predict(batch)`` returning an array indexable
            like ``BatchLabel``.
        BatchData: 5-D input array unpacked as ``(B, T, C, W, H)``; each frame
            is written as a ``W x H`` block.
        BatchLabel: target array whose axis 1 is time.
        plot: existing ``Visdom`` client; created on ``port`` when omitted.
        name: env name and window-name prefix.
        port: port for a newly created client.
        max_frames: maximum number of time steps shown (previously fixed at 10).
        gap: zero-filled columns between consecutive frames (previously 9).
    """
    if plot is None:
        plot = Visdom(port=port)
    B, T, C, W, H = BatchData.shape
    pred_T = BatchLabel.shape[1]

    batch_id = random.randint(0, B - 1)
    data_len = min(max_frames, T)
    pred_len = min(max_frames, pred_T)

    # Slices keep the batch axis so the model sees a batch of one.
    inputdata = BatchData[batch_id:batch_id + 1, ...]
    prediction = strumodel.predict(inputdata)
    labeldata = BatchLabel[batch_id:batch_id + 1, ...]

    batch_content = _tile_frames(inputdata, data_len, W, H, gap)
    label_content = _tile_frames(labeldata, pred_len, W, H, gap)
    predict_content = _tile_frames(prediction, pred_len, W, H, gap)

    plot.heatmap(X=np.flipud(batch_content),
                 win=name + '_OriginalImage',
                 opts=dict(title=name + '_OriginalImage'),
                 env=name)
    plot.heatmap(X=np.flipud(label_content),
                 win=name + '_GroundTruth',
                 opts=dict(title=name + '_GroundTruth'),
                 env=name)
    plot.heatmap(X=np.flipud(predict_content),
                 win=name + '_Prediction',
                 opts=dict(title=name + '_Prediction'),
                 env=name)
コード例 #11
0
class VisdomPlotter:
    """
    A Visdom based plotter, to plot aggregated metrics.

    How to use:
    ------------
    (1) Start the server with:
            python -m visdom.server
    (2) Then, in your browser, you can go to:
            http://localhost:8097

    If no server answers at construction time, one is launched in the
    background with the current interpreter on the configured port.
    """
    def __init__(self, experiment_env, server='http://localhost', port=8097):
        self.server = server
        self.port = port
        self.viz = Visdom(
            server=server,
            port=port)  # Connect to Visdom server on server / port
        if not self.start_visdom_server():
            raise ValueError('Failed to launch Visdom server at %r:%r' %
                             (server, port))

        if experiment_env in self.viz.get_env_list():
            self.viz.delete_env(
                experiment_env)  # Clear previous runs with same id
        self.experiment_env = experiment_env
        self.plots = {}  # metric name -> Visdom window handle

    def start_visdom_server(self):
        """Return whether a Visdom server answers, launching a local one if needed.

        The existing server is pinged with a 1 s timeout. On failure,
        ``<sys.executable> -m visdom.server -port <self.port>`` is started in
        the background and the connection is re-checked with a 35 s timeout.
        """
        import subprocess

        is_visdom_server_connected = self.viz.check_connection(
            timeout_seconds=1)  # Ping if it's already on..
        if is_visdom_server_connected:
            return is_visdom_server_connected
        # An argument list (no shell) handles interpreter paths with spaces
        # and works on Windows; '-port' makes the server listen where the
        # client connects instead of on the default 8097.
        self._server_process = subprocess.Popen(
            [sys.executable, '-m', 'visdom.server', '-port', str(self.port)])
        return self.viz.check_connection(timeout_seconds=35)

    def plot_single_metric(self, metric, line_id, title, epoch, value):
        """Add ``(epoch, value)`` to trace ``line_id`` of the ``metric`` chart."""
        if metric not in self.plots:
            self.plots[metric] = self.viz.line(X=np.array([epoch, epoch]),
                                               Y=np.array([value, value]),
                                               env=self.experiment_env,
                                               opts=dict(legend=[line_id],
                                                         title=title,
                                                         xlabel='Epochs',
                                                         ylabel=metric))
        else:
            self.viz.line(X=np.array([epoch]),
                          Y=np.array([value]),
                          env=self.experiment_env,
                          win=self.plots[metric],
                          name=line_id,
                          update='append')

    def plot_confusion_matrix(self, metric, matrix, label_classes):
        """Show ``matrix`` as a heatmap labelled by ``label_classes`` on both axes."""
        if metric not in self.plots:
            self.plots[metric] = self.viz.heatmap(
                X=matrix,
                env=self.experiment_env,
                opts=dict(columnnames=label_classes, rownames=label_classes))
        else:
            self.viz.heatmap(X=matrix,
                             env=self.experiment_env,
                             win=self.plots[metric],
                             opts=dict(columnnames=label_classes,
                                       rownames=label_classes))

    def plot_images(self, images_bchw):
        """Show a batch of images (batch x channel x height x width) in a new window."""
        self.viz.images(images_bchw)

    def plot_aggregated_metrics(self, metrics, epoch):
        """Plot every metric that ``metrics`` recorded for ``epoch``.

        ``'confusion_matrix'`` becomes a heatmap. Array values with more than
        one element get one trace per class (``metrics.label_classes[idx]``);
        anything else is one trace named ``metrics.data_type``.
        """
        for metric in metrics.metrics:
            title = metrics.metric_to_title[metric]
            value = metrics[epoch][metric]

            if metric == 'confusion_matrix':
                label_classes = metrics.label_classes
                self.plot_confusion_matrix(metric, value, label_classes)
            elif hasattr(value, 'shape') and value.size > 1:
                for idx, dim_val in enumerate(value):
                    line_id = metrics.label_classes[idx]
                    self.plot_single_metric(metric, line_id, title, epoch,
                                            dim_val)
            else:
                line_id = metrics.data_type
                self.plot_single_metric(metric, line_id, title, epoch, value)
コード例 #12
0
def _imresize_nearest(arr, size):
    """Nearest-neighbour resize of a 2-D array to ``size x size``, returned as uint8.

    Stand-in for ``scipy.misc.imresize(arr, [size, size, ...], interp='nearest')``,
    which was removed in SciPy 1.3. It mirrors that function's two steps:
    non-uint8 input is first byte-scaled so its min..max spans 0..255 (a
    constant array maps to 0), then pixels are sampled with pixel-centre
    nearest-neighbour mapping.
    """
    arr = np.asarray(arr)
    if arr.dtype != np.uint8:
        data = arr.astype(float)
        lo, hi = data.min(), data.max()
        span = (hi - lo) or 1
        arr = (np.clip((data - lo) * (255.0 / span), 0, 255)
               + 0.5).astype(np.uint8)
    rows = ((np.arange(size) + 0.5) * arr.shape[0] / size).astype(int)
    cols = ((np.arange(size) + 0.5) * arr.shape[1] / size).astype(int)
    return arr[np.ix_(rows, cols)]


c1 = _imresize_nearest(grid_image[:, :, 0], grid_imsize)  # obstacles
c2 = _imresize_nearest(grid_image[:, :, 1], grid_imsize)  # goal
c3 = _imresize_nearest(np.zeros([gridsize, gridsize]), grid_imsize)  # nothing
# Combine as a RGB image
grid_image = np.flipud(np.stack([c1, c2, c3], 2))

# Visdom expects channel-first (C x H x W) images.
grid_image = grid_image.transpose([2, 0, 1])  # TEMP

# Create a visdom object
vis = Visdom()

# Image for grid image
vis.image(grid_image, opts=dict(title='Grid world', caption='Test'))

# Heatmap for reward image
vis.heatmap(reward_image.squeeze(), opts=dict(colormap='Electric'))

# Heatmap for value image
for img in value_images:
    vis.heatmap(img.squeeze(), opts=dict(colormap='Electric'))

print('Finished. Please open visdom page.')
コード例 #13
0
ファイル: callbacks.py プロジェクト: wayne980/Hyperbolic_ZSL
class Callback(object):
    """
    Used to log/visualize the evaluation metrics during training. The values are stored at the end of each epoch.
    """
    def __init__(self, metrics):
        """
        Args:
            metrics : a list of callbacks. Possible values:
                "CoherenceMetric"
                "PerplexityMetric"
                "DiffMetric"
                "ConvergenceMetric"
        """
        # list of metrics to be plot
        self.metrics = metrics

    def _needs_previous_model(self):
        """Return True if any metric compares against the previous epoch's model."""
        return any(
            isinstance(metric, (DiffMetric, ConvergenceMetric))
            for metric in self.metrics)

    def set_model(self, model):
        """
        Save the model instance and initialize any required variables which would be updated throughout training
        """
        self.model = model
        self.previous = None
        if self._needs_previous_model():
            self.previous = copy.deepcopy(model)
            # store diff diagonals of previous epochs
            self.diff_mat = Queue()
        if any(metric.logger == "visdom" for metric in self.metrics):
            if not VISDOM_INSTALLED:
                raise ImportError("Please install Visdom for visualization")
            self.viz = Visdom()
            # Plot window per metric index, reused on later epochs. A dict,
            # so that non-visdom metrics do not shift window positions.
            self.windows = {}
        if any(metric.logger == "shell" for metric in self.metrics):
            # set logger for current topic model
            self.log_type = logging.getLogger('gensim.models.ldamodel')

    def _plot_visdom(self, i, metric, label, epoch, value):
        """Draw (epoch 0) or update (later epochs) the Visdom window of metric ``i``.

        Array values (ndim > 0) are stacked over epochs and shown as a
        heatmap; scalars become a line chart.
        """
        opts = dict(xlabel='Epochs', ylabel=label, title=label)
        is_array = np.ndim(value) > 0
        if epoch == 0:
            if is_array:
                diff_mat = np.array([value])
                self.windows[i] = self.viz.heatmap(X=diff_mat.T,
                                                   env=metric.viz_env,
                                                   opts=opts)
                # store current epoch's diff diagonal
                self.diff_mat.put(diff_mat)
            else:
                self.windows[i] = self.viz.line(Y=np.array([value]),
                                                X=np.array([epoch]),
                                                env=metric.viz_env,
                                                opts=opts)
        elif is_array:
            # concatenate with previous epoch's diff diagonals; the FIFO
            # queue hands each array metric back its own history in order.
            diff_mat = np.concatenate(
                (self.diff_mat.get(), np.array([value])))
            self.viz.heatmap(X=diff_mat.T,
                             env=metric.viz_env,
                             win=self.windows[i],
                             opts=opts)
            self.diff_mat.put(diff_mat)
        else:
            # ``updateTrace`` was removed from Visdom; append via ``line``.
            self.viz.line(Y=np.array([value]),
                          X=np.array([epoch]),
                          env=metric.viz_env,
                          win=self.windows[i],
                          update='append')

    def on_epoch_end(self, epoch, topics=None):
        """
        Log or visualize current epoch's metric value

        Args:
            epoch : current epoch no.
            topics : topic distribution from current epoch (required for coherence of unsupported topic models)

        Returns:
            dict mapping ``str(metric)`` to this epoch's value.
        """
        # stores current epoch's metric values
        current_metrics = {}

        # plot all metrics in current epoch
        for i, metric in enumerate(self.metrics):
            label = str(metric)
            value = metric.get_value(topics=topics,
                                     model=self.model,
                                     other_model=self.previous)

            current_metrics[label] = value

            if metric.logger == "visdom":
                self._plot_visdom(i, metric, label, epoch, value)

            if metric.logger == "shell":
                statement = "".join(
                    ("Epoch ", str(epoch), ": ", label, " estimate: ",
                     str(value)))
                self.log_type.info(statement)

        # Refresh the reference model if ANY metric needs it. The old check
        # looked only at the last loop variable and raised NameError when
        # there were no metrics.
        if self._needs_previous_model():
            self.previous = copy.deepcopy(self.model)

        return current_metrics
コード例 #14
0
ファイル: copy_task.py プロジェクト: ixaxaar/pytorch-dni
                raise Exception('nan Loss')

        if summarize and debug_enabled:
            loss = np.mean(last_save_losses)
            # print(input_data)
            # print("1111111111111111111111111111111111111111111111")
            # print(target_output)
            # print('2222222222222222222222222222222222222222222222')
            # print(F.relu6(output))
            last_save_losses = []

            if args.memory_type == 'dnc':
                viz.heatmap(v['memory'],
                            opts=dict(xtickstep=10,
                                      ytickstep=2,
                                      title='Memory, t: ' + str(epoch) +
                                      ', loss: ' + str(loss),
                                      ylabel='layer * time',
                                      xlabel='mem_slot * mem_size'))

            if args.memory_type == 'dnc':
                viz.heatmap(v['link_matrix'][-1].reshape(
                    args.mem_slot, args.mem_slot),
                            opts=dict(xtickstep=10,
                                      ytickstep=2,
                                      title='Link Matrix, t: ' + str(epoch) +
                                      ', loss: ' + str(loss),
                                      ylabel='mem_slot',
                                      xlabel='mem_slot'))
            elif args.memory_type == 'sdnc':
                viz.heatmap(v['link_matrix'][-1].reshape(args.mem_slot, -1),
import os

mdl = load_model()
dg = dataloader(batch_size, path, phase)
iteartion_per_epoch = len(dg.dataset) // batch_size
optimizer = optim.Adam(mdl.parameters(), lr=0.001)
# schedular.step() runs once per iteration, so step_size counts iterations.
schedular = StepLR(optimizer, step_size=10000, gamma=0.1)
viz = Visdom()
# torch.save does not create missing parent directories.
os.makedirs('./weight', exist_ok=True)
counter = 0
for i in range(iteartion_per_epoch * 10):
    image, label = dg.create_batch()
    # NOTE(review): only the label is moved to the GPU; `image` is used as
    # returned by create_batch -- confirm it is already on the model's device.
    label = label.type(torch.FloatTensor).cuda()
    optimizer.zero_grad()
    # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4
    # (this script already relies on 0.4+ via loss.item()).
    output = mdl(image)
    # Sum of the per-channel losses over the three output maps.
    loss = torch.mean(
        loss_function(output[:, 0, :, :], label[:, 0, :, :])
        + loss_function(output[:, 1, :, :], label[:, 1, :, :])
        + loss_function(output[:, 2, :, :], label[:, 2, :, :]))
    loss.backward()
    optimizer.step()
    schedular.step()
    print('iteration :', i, 'loss ', loss.item())
    if i % 100 == 0 and i > 0:
        viz.image(convert(image[0, ...]), "image")
        viz.heatmap(convert(output[0, 0, :, :]), "heatmap1")
        viz.heatmap(convert(output[0, 1, :, :]), "heatmap2")
        viz.heatmap(convert(output[0, 2, :, :]), "heatmap3")
    if i % 5000 == 0 and i > 0:
        torch.save(mdl.state_dict(), './weight/' + str(counter) + '.pt')
    counter += 1


コード例 #16
0
ファイル: logger.py プロジェクト: zc280330/pytorch-asr
class VisdomLogger:
    """Visdom logger whose windows are shared across distributed ranks.

    Windows are looked up on the server by title and type. A process with
    ``rank > 0`` waits for rank 0 to create a window and then reuses it.
    Line/scatter windows get one trace per (run, rank), named
    ``"<n>"`` or ``"<n>_<rank>"``.
    """

    def __init__(self,
                 host='127.0.0.1',
                 port=8097,
                 env='main',
                 log_path=None,
                 rank=None):
        from visdom import Visdom
        logger.debug(f"using visdom on http://{host}:{port} env={env}")
        self.env = env
        self.rank = rank
        self.viz = Visdom(server=f"http://{host}",
                          port=port,
                          env=env,
                          log_to_filename=log_path)
        self.windows = {}  # title -> {'handle', 'name', 'opts'}
        # Replay an earlier event log, only from rank 0 or a single process.
        should_replay = (log_path is not None and log_path.exists()
                         and (rank is None or rank == 0))
        if should_replay:
            self.viz.replay_log(log_path)

    def _get_win(self, title, type):
        """Return ``(handle, content)`` of the matching window with the smallest id, else ``(None, None)``."""
        import json
        all_windows = json.loads(
            self.viz.get_window_data(win=None, env=self.env))
        best = min(((handle, info) for handle, info in all_windows.items()
                    if info['title'] == title and info['type'] == type),
                   key=lambda pair: pair[0],
                   default=None)
        if best is None:
            return None, None
        return best[0], best[1]['content']

    def _get_rank0_win(self, title, type):
        """Look up a window; ranks > 0 poll (10 x 0.5 s) until rank 0 has created it.

        Raises:
            RuntimeError: a rank > 0 never found the window.
        """
        if self.rank is None or not (self.rank > 0):
            return self._get_win(title, type)
        for _ in range(10):
            handle, content = self._get_win(title, type)
            if handle is not None:
                return handle, content
            time.sleep(0.5)
        logger.error(
            "couldn't get a proper window handle from the visdom server")
        raise RuntimeError

    def _window_types(self, cmd):
        """Map a Visdom plotting method to its (window type, plot kind) pair."""
        viz = self.viz
        if cmd == viz.images:
            return ("image", None)
        if cmd == viz.scatter or cmd == viz.line:
            return ("plot", "scatter")
        if cmd == viz.heatmap:
            return ("plot", "heatmap")
        return ("plot", None)

    def _new_window(self, cmd, title, **cmd_args):
        """Create the window for ``title`` (or join an existing one) and record it."""
        types = self._window_types(cmd)
        is_trace = types == ("plot", "scatter")

        handle, content = self._get_rank0_win(title, types[0])

        if handle is None:
            cmd_args.setdefault('opts', {}).update({"title": title})
            if is_trace:
                name = "1" if self.rank is None else f"1_{self.rank}"
                handle = cmd(name=name, **cmd_args)
            else:
                name = None
                handle = cmd(**cmd_args)
        elif is_trace:
            # Next run number = highest existing run prefix + 1.
            last = max([
                int(trace['name'].partition('_')[0])
                for trace in content['data']
            ])
            name = (f"{last + 1}" if self.rank is None
                    else f"{last + 1}_{self.rank}")
            cmd(win=handle, name=name, update="append", **cmd_args)
        else:
            name = None
            handle = cmd(win=handle, **cmd_args)

        self.windows[title] = {
            'handle': handle,
            'name': name,
            'opts': cmd_args["opts"],
        }

    def add_point(self, title, x, y, **kwargs):
        """Append the point ``(x, y)`` to this process's trace in line chart ``title``."""
        X = torch.FloatTensor([x])
        Y = torch.FloatTensor([y])
        if title not in self.windows:
            self._new_window(self.viz.line, title, X=X, Y=Y, opts=kwargs)
            return
        entry = self.windows[title]
        entry['opts'].update(kwargs)
        self.viz.line(win=entry['handle'],
                      update='append',
                      Y=Y,
                      X=X,
                      name=entry['name'],
                      opts=entry['opts'])

    def plot_heatmap(self, title, tensor, **kwargs):
        """Show or refresh ``tensor`` as heatmap window ``title``."""
        if title not in self.windows:
            self._new_window(self.viz.heatmap, title, X=tensor, opts=kwargs)
            return
        entry = self.windows[title]
        entry['opts'].update(kwargs)
        self.viz.heatmap(win=entry['handle'], X=tensor, opts=entry['opts'])

    def plot_images(self, title, tensor, nrow, **kwargs):
        """Show or refresh the image grid ``tensor`` (``nrow`` per row) in window ``title``."""
        if title not in self.windows:
            self._new_window(self.viz.images, title, tensor=tensor,
                             nrow=nrow, opts=kwargs)
            return
        entry = self.windows[title]
        entry['opts'].update(kwargs)
        self.viz.images(win=entry['handle'], tensor=tensor, nrow=nrow,
                        opts=entry['opts'])
コード例 #17
0
ファイル: callbacks.py プロジェクト: jMonteroMunoz/gensim
class Callback(object):
    """
    Used to log/visualize the evaluation metrics during training. The values are stored at the end of each epoch.
    """
    def __init__(self, metrics):
        """
        Args:
            metrics : a list of metric objects. Possible values:
                "CoherenceMetric"
                "PerplexityMetric"
                "DiffMetric"
                "ConvergenceMetric"
        """
        # list of metrics to be plotted / logged
        self.metrics = metrics

    def set_model(self, model):
        """
        Save the model instance and initialize any required variables which would be updated throughout training

        Args:
            model : the model being trained. A deep copy is kept when some
                metric compares against the previous epoch's model.

        Raises:
            ImportError : if a metric logs to visdom but Visdom is not installed.
        """
        self.model = model
        self.previous = None
        # check for any metric which need model state from previous epoch
        if self._needs_previous_model():
            self.previous = copy.deepcopy(model)
            # store diff diagonals of previous epochs
            self.diff_mat = Queue()
        if any(metric.logger == "visdom" for metric in self.metrics):
            if not VISDOM_INSTALLED:
                raise ImportError("Please install Visdom for visualization")
            self.viz = Visdom()
            # one plot window per metric, indexed like self.metrics (None for
            # non-visdom metrics) so windows can be looked up by metric index
            self.windows = [None] * len(self.metrics)
        if any(metric.logger == "shell" for metric in self.metrics):
            # set logger for current topic model
            self.log_type = logging.getLogger('gensim.models.ldamodel')

    def _needs_previous_model(self):
        """Return True if any metric compares the model against the previous epoch's model."""
        return any(isinstance(metric, (DiffMetric, ConvergenceMetric)) for metric in self.metrics)

    def on_epoch_end(self, epoch, topics=None):
        """
        Log or visualize current epoch's metric value

        Args:
            epoch : current epoch no.
            topics : topic distribution from current epoch (required for coherence of unsupported topic models)

        Returns:
            dict mapping ``str(metric)`` to the value computed for this epoch.
        """
        # stores current epoch's metric values
        current_metrics = {}

        for i, metric in enumerate(self.metrics):
            label = str(metric)
            value = metric.get_value(topics=topics, model=self.model, other_model=self.previous)

            current_metrics[label] = value

            if metric.logger == "visdom":
                self._plot_visdom(i, metric, label, value, epoch)

            if metric.logger == "shell":
                self.log_type.info("Epoch %s: %s estimate: %s", epoch, label, value)

        # refresh the stored model state if ANY metric needs it next epoch
        # (checking only the loop variable would look at the last metric only
        # and raise NameError for an empty metric list)
        if self._needs_previous_model():
            self.previous = copy.deepcopy(self.model)

        return current_metrics

    def _plot_visdom(self, i, metric, label, value, epoch):
        """Create (epoch 0) or update the visdom window of the ``i``-th metric.

        Array-valued metrics (e.g. diff diagonals) are drawn as a heatmap that
        gains one row of history per epoch; scalar metrics are drawn as a line.
        """
        opts = dict(xlabel='Epochs', ylabel=label, title=label)
        # np.ndim also accepts plain Python numbers (returns 0)
        is_array = np.ndim(value) > 0
        if epoch == 0:
            if is_array:
                diff_mat = np.array([value])
                viz_metric = self.viz.heatmap(X=diff_mat.T, env=metric.viz_env, opts=opts)
                # store current epoch's diff diagonal
                self.diff_mat.put(diff_mat)
            else:
                viz_metric = self.viz.line(
                    Y=np.array([value]), X=np.array([epoch]), env=metric.viz_env, opts=opts
                )
            # saving initial plot window (same window is updated in later epochs)
            self.windows[i] = copy.deepcopy(viz_metric)
        elif is_array:
            # concatenate with previous epoch's diff diagonals
            diff_mat = np.concatenate((self.diff_mat.get(), np.array([value])))
            self.viz.heatmap(X=diff_mat.T, env=metric.viz_env, win=self.windows[i], opts=opts)
            self.diff_mat.put(diff_mat)
        else:
            # Visdom.updateTrace was removed; appending is done through line()
            self.viz.line(
                Y=np.array([value]), X=np.array([epoch]), env=metric.viz_env,
                win=self.windows[i], update='append'
            )
コード例 #18
0
                                   opts=dict(title='VALIDATION_ACCURACY',
                                             xlabel='EPOCH'))

    colnames = train_query_list[0].split()
    rownames = train_layout_list[0].split()
    colnames = [
        '{}{}'.format(chr(24) * i, colnames[i]) for i in range(len(colnames))
    ]
    rownames = [
        '{}{}'.format(chr(24) * i, rownames[i][1:])
        for i in range(len(rownames))
    ]

    attention_heatmap = viz.heatmap(X=np.zeros((5, 9)),
                                    opts=dict(title='ATTENTION HEATMAP',
                                              columnnames=colnames,
                                              rownames=rownames,
                                              colormap='Jet'))

    # Change to use named_parameters() and autoinference, this is hacky
    w1w = viz.line(X=np.array([0]),
                   Y=np.array([0]),
                   opts=dict(title='seq2seq: W1.weight', xlabel='EPOCH'))
    w1g = viz.line(X=np.array([0]),
                   Y=np.array([0]),
                   opts=dict(title='seq2seq: W1.grad', xlabel='EPOCH'))

    w2w = viz.line(X=np.array([0]),
                   Y=np.array([0]),
                   opts=dict(title='seq2seq: W2.weight', xlabel='EPOCH'))
    w2g = viz.line(X=np.array([0]),
コード例 #19
0
ファイル: callbacks.py プロジェクト: AriMKatz/Torchelie
class VisdomLogger(tu.AutoStateDict):
    """
    Log metrics to Visdom. It logs scalars and scalar tensors as plots, 2D
    tensors as heatmaps, 3D and 4D tensors as images, and strings as text.

    Args:
        visdom_env (str): name of the target visdom env; ``None`` disables
            logging entirely
        log_every (int): batch logging freq. -1 logs on epoch ends only.
        prefix (str): prefix for all metrics name
    """
    def __init__(self, visdom_env='main', log_every=10, prefix=''):
        super(VisdomLogger, self).__init__(except_names=['vis'])
        self.vis = None
        self.log_every = log_every
        self.prefix = prefix
        if visdom_env is not None:
            self.vis = Visdom(env=visdom_env)
            # close() without a window id clears the env's existing windows
            self.vis.close()

    def _should_log_batch(self, iters):
        """Return True when batch ``iters`` falls on the batch-logging period."""
        return self.log_every != -1 and iters % self.log_every == 0

    def on_batch_start(self, state):
        state['visdom_will_log'] = self._should_log_batch(state['iters'])

    @torch.no_grad()
    def on_batch_end(self, state):
        iters = state['iters']
        if self._should_log_batch(iters):
            self.log(iters, state['metrics'])

    def on_epoch_end(self, state):
        self.log(state['iters'], state['metrics'])

    def _plot_scalar(self, iters, name, value):
        """Append the point ``(iters, value)`` to the line window ``name``."""
        self.vis.line(X=[iters],
                      Y=[value],
                      update='append',
                      win=name,
                      opts=dict(title=name),
                      name=name)

    def log(self, iters, xs, store_history=()):
        """Send every entry of ``xs`` to its own Visdom window.

        Args:
            iters (int): x coordinate used for scalar plots
            xs (dict): metric name -> value (int/float, str or torch.Tensor)
            store_history (container of str): prefixed metric names whose
                image windows should keep a history of past images

        Raises:
            AssertionError: for unsupported value types or tensor ranks.
        """
        if self.vis is None:
            return

        for name, x in xs.items():
            name = self.prefix + name
            if isinstance(x, (float, int)):
                self._plot_scalar(iters, name, x)
            elif isinstance(x, str):
                self.vis.text(x, win=name, opts=dict(title=name))
            elif isinstance(x, torch.Tensor):
                if x.numel() == 1:
                    self._plot_scalar(iters, name, x.item())
                elif x.dim() == 2:
                    self.vis.heatmap(x, win=name, opts=dict(title=name))
                elif x.dim() == 3:
                    self.vis.image(x,
                                   win=name,
                                   opts=dict(title=name,
                                             store_history=name
                                             in store_history))
                elif x.dim() == 4:
                    self.vis.images(x,
                                    win=name,
                                    opts=dict(title=name,
                                              store_history=name
                                              in store_history))
                else:
                    assert False, "incorrect tensor dim"
            else:
                assert False, "incorrect type " + x.__class__.__name__
コード例 #20
0
# Stacked bar chart: 5 bars (one per row name), each made of 3 stacked parts
viz.bar(
    X=np.abs(np.random.rand(5, 3)),
    opts={
        'stacked': True,
        'legend': ['A', 'B', 'C'],
        'rownames': ['2012', '2013', '2014', '2015', '2016']
    })

# Grouped (side-by-side) bar chart: 20 groups of 3 bars
viz.bar(X=np.random.rand(20, 3),
        opts={
            'stacked': False,
            'legend': ['America', 'Britain', 'China']
        })

# Heatmap, contour plot and surface plot
viz.heatmap(X=np.outer(np.arange(1, 6), np.arange(1, 11)),
            opts={
                'columnnames':
                ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'],
                'rownames': ['y1', 'y2', 'y3', 'y4', 'y5'],
                'colormap': 'Electric'
            })

# Contour plot of a 2-D Gaussian bump centred at (50, 50) on a 100x100 grid
x = np.tile(np.arange(1, 101), (100, 1))
y = x.transpose()
X = np.exp((((x - 50)**2) + ((y - 50)**2)) / -(20.0**2))
viz.contour(X=X, opts=dict(colormap='Viridis'))

# Surface plot of the same data
viz.surf(X=X, opts={'colormap': 'Hot'})
コード例 #21
0
ファイル: main.py プロジェクト: ixaxaar/awd-dnc-lm
                                  weight_decay=args.wdecay)  # 0.01
        elif args.optim == 'adagrad':
            optimizer = optim.Adagrad(model.parameters(), lr=lr)
        elif args.optim == 'adadelta':
            optimizer = optim.Adadelta(model.parameters(), lr=lr)

        epoch_start_time = time.time()
        v = train()

        val_loss2 = evaluate(val_data)

        if False:
            viz.heatmap(v[0]['memory'],
                        opts=dict(xtickstep=10,
                                  ytickstep=2,
                                  title='Memory, t: ' + str(epoch) +
                                  ', ppx: ' + str(math.exp(val_loss2)),
                                  ylabel='layer * time',
                                  xlabel='mem_slot * mem_size'))

            viz.heatmap(v[0]['link_matrix'],
                        opts=dict(xtickstep=10,
                                  ytickstep=2,
                                  title='Link Matrix, t: ' + str(epoch) +
                                  ', ppx: ' + str(math.exp(val_loss2)),
                                  ylabel='layer * time',
                                  xlabel='mem_slot * mem_slot'))

            viz.heatmap(v[0]['precedence'],
                        opts=dict(xtickstep=10,
                                  ytickstep=2,
コード例 #22
0
ファイル: MapClient.py プロジェクト: phate09/bham_brigade_py
class SampleHazardDetector(IDataReceived):
    """Build a fire heatmap from LMCP hazard detections and report hazard zones.

    ``HazardZoneDetection`` messages mark cells of a ``(SIZE_LAT, SIZE_LONG)``
    grid (``self.heatmap``). The grid and a Gaussian-smoothed copy
    (``self.smooth``) are shown in Visdom. Polygons computed from the marked
    cells are sent through ``tcpClient`` as ``HazardZoneEstimateReport``
    messages.
    """

    def __init__(self, tcpClient):
        """Connect to Visdom and initialise empty detection state.

        Args:
            tcpClient: client used to send LMCP objects (``sendLMCPObject``).

        The Visdom host and port come from the command line (``-server``,
        ``-port``). The constructor asserts that a connection can be made
        within 3 seconds.
        """
        DEFAULT_PORT = 8097
        DEFAULT_HOSTNAME = "http://localhost"
        parser = argparse.ArgumentParser(description='Demo arguments')
        parser.add_argument('-port',
                            metavar='port',
                            type=int,
                            default=DEFAULT_PORT,
                            help='port the visdom server is running on.')
        parser.add_argument(
            '-server',
            metavar='server',
            type=str,
            default=DEFAULT_HOSTNAME,
            help='Server address of the target to run the demo on.')
        FLAGS = parser.parse_args()
        self.viz = Visdom(port=FLAGS.port, server=FLAGS.server)
        self.fake_point = False  # whether to import the fire boundaries from the scenario xml
        self.last_refresh = 0
        self.new_points_detected = False

        assert self.viz.check_connection(
            timeout_seconds=3
        ), 'No connection could be formed quickly, remember to run \'visdom\' in the terminal'

        self.__client = tcpClient
        self.__uavsLoiter = {}
        self.__estimatedHazardZone = Polygon()
        self.filename = None
        # cells set to 1.0 where fire was detected
        self.heatmap = np.zeros((SIZE_LAT, SIZE_LONG))
        # the time at which the fire was last detected (or not) in that cell
        self.last_detected = np.zeros((SIZE_LAT, SIZE_LONG))
        self.smooth = np.zeros((SIZE_LAT, SIZE_LONG))
        self.drones_status = {}
        self.current_time = 0
        self.force_recompute_times = []
        self.communication_channel = protobuf.communication_client.CommunicationChannel(
        )

    def load_scenario(self, filename):
        """Parse the scenario XML ``filename`` and derive the map bounds.

        The bounds default to the ``SimulationView`` centre +/- ``LongExtent``.
        They are narrowed to the first ``KeepInZone`` rectangle when one is
        present. ``Hack`` elements provide the scoring times that force a
        recomputation of the estimate.
        """
        print("loading scenario")
        self.scenario = minidom.parse(filename)
        simulation_view_node = self.scenario.getElementsByTagName(
            'SimulationView')
        self.latitude = float(
            simulation_view_node[0].attributes['Latitude'].value)
        self.longitude = float(
            simulation_view_node[0].attributes['Longitude'].value)
        self.long_extent = float(
            simulation_view_node[0].attributes['LongExtent'].value)
        self.max_lat = self.latitude + self.long_extent
        self.min_lat = self.latitude - self.long_extent
        self.max_long = self.longitude + self.long_extent
        self.min_long = self.longitude - self.long_extent
        self.last_refresh = 0
        self.new_points_detected = False
        keep_in_zone = self.scenario.getElementsByTagName('KeepInZone')
        ############### KEEP IN ZONE ############################
        if len(keep_in_zone) > 0:  # if there is a keep in zone then constrain to it
            self.latitude = float(keep_in_zone[0].getElementsByTagName(
                'Latitude')[0].childNodes[0].nodeValue)
            self.longitude = float(keep_in_zone[0].getElementsByTagName(
                'Longitude')[0].childNodes[0].nodeValue)
            width = float(keep_in_zone[0].getElementsByTagName('Width')
                          [0].childNodes[0].nodeValue)
            height = float(keep_in_zone[0].getElementsByTagName('Height')
                           [0].childNodes[0].nodeValue)
            centerPoint = Location3D()
            centerPoint.set_Latitude(self.latitude)
            centerPoint.set_Longitude(self.longitude)
            # NOTE(review): height is passed as dx (east-west offset) and width
            # as dy (north-south offset); confirm this matches the KeepInZone
            # rectangle convention (only matters for non-square zones).
            low_loc: Location3D = self.newLocation(centerPoint, height / -2,
                                                   width / -2)
            high_loc: Location3D = self.newLocation(centerPoint, height / 2,
                                                    width / 2)
            self.max_lat = max(low_loc.get_Latitude(), high_loc.get_Latitude())
            self.max_long = max(low_loc.get_Longitude(),
                                high_loc.get_Longitude())
            self.min_lat = min(low_loc.get_Latitude(), high_loc.get_Latitude())
            self.min_long = min(low_loc.get_Longitude(),
                                high_loc.get_Longitude())
        ################# SCORING TIMES ######################
        scoring_times = self.scenario.getElementsByTagName("Hack")
        self.force_recompute_times = []
        for scoring_time in scoring_times:
            scoring_at = float(scoring_time.attributes['Time'].value)
            self.force_recompute_times.append(scoring_at)

        ################ FAKE POINTS
        if self.fake_point:
            self.fake_points()
        print('scenario loaded')

    def newLocation(self, loc: Location3D, dx, dy):
        """Return a new location offset from ``loc`` by ``dx``/``dy`` metres.

        ``dy`` shifts the latitude and ``dx`` shifts the longitude, using a
        small-offset approximation with an Earth radius of 6372.8 km. The
        longitude shift is scaled by ``1 / cos(latitude)``.
        """
        R_EARTHKM = 6372.8
        latitude = loc.get_Latitude()
        longitude = loc.get_Longitude()
        new_latitude = latitude + (dy / (R_EARTHKM * 1000)) * (180 / math.pi)

        new_longitude = longitude + (dx / (R_EARTHKM * 1000)) * (
            180 / math.pi) / math.cos(latitude * math.pi / 180)

        new_location = Location3D()
        new_location.set_Latitude(new_latitude)
        new_location.set_Longitude(new_longitude)
        return new_location

    def _in_grid(self, lat, long):
        """Return True if ``(lat, long)`` indexes a real heatmap cell.

        Needed because negative indices from ``normalise_coordinates``
        (points below/left of the map) would otherwise silently wrap around
        to the opposite edge of the numpy array.
        """
        return (0 <= lat < self.heatmap.shape[0]
                and 0 <= long < self.heatmap.shape[1])

    def fake_points(self):
        """Mark the scenario's ``HazardZone`` boundary points as detected fire.

        Then smooth the map, refresh Visdom and force an estimate report.
        """
        hazardZone_nodes = self.scenario.getElementsByTagName('HazardZone')
        for hazardZone_node in hazardZone_nodes:
            boundary_points = hazardZone_node.getElementsByTagName(
                'Location3D')
            for point_string in boundary_points:
                latitude = float(
                    point_string.getElementsByTagName('Latitude')
                    [0].childNodes[0].nodeValue)
                longitude = float(
                    point_string.getElementsByTagName('Longitude')
                    [0].childNodes[0].nodeValue)
                location = Location3D()
                location.set_Latitude(latitude)
                location.set_Longitude(longitude)
                lat, long = self.normalise_coordinates(location)
                if not self._in_grid(lat, long):
                    print(f'point ({latitude}, {longitude}) is outside the heatmap, skipped')
                    continue
                self.heatmap[lat][long] = 1.0
                self.last_detected[lat][
                    long] = self.current_time  # the last registered time
        self.apply_smoothing()
        self.update_visdom()
        self.compute_and_send_estimate_hazardZone(True)

    def dataReceived(self, lmcpObject):
        """Handle one incoming LMCP message.

        * ``SessionStatus``: record the scenario time. On reset, clear all
          state and load the scenario named by the ``source`` parameter. Once a
          scenario is loaded, publish the heatmap and possibly a new estimate.
        * ``AirVehicleState``: currently ignored.
        * ``HazardZoneDetection``: mark the detected cell and refresh the
          Visdom plots if the cell was not already marked.

        Any exception is printed and not propagated to the caller.
        """
        try:
            if isinstance(lmcpObject, SessionStatus):
                print(f'time: {self.current_time} - session status'.ljust(100),
                      end='\r',
                      flush=True)
                session_status: SessionStatus = lmcpObject
                # save the last registered time to use in other parts of the code
                self.current_time = session_status.get_ScenarioTime()
                state: SimulationStatusType.SimulationStatusType = session_status.get_State(
                )
                if state is SimulationStatusType.SimulationStatusType.Reset:
                    self.viz.close(win=HEATMAP)
                    self.viz.close(win=CONTOUR)
                    self.viz.close(win="Trajectory")
                    self.scenario = None
                    self.filename = None  # scenario not ready
                    self.heatmap = np.zeros((SIZE_LAT, SIZE_LONG))
                    # the time at which the fire was last detected (or not) in that cell
                    self.last_detected = np.zeros((SIZE_LAT, SIZE_LONG))
                    self.smooth = np.zeros((SIZE_LAT, SIZE_LONG))
                    self.drones_status = {}
                    self.force_recompute_times = []
                    param: KeyValuePair
                    for param in session_status.get_Parameters():
                        if param.Key == b'source':
                            self.filename = param.Value.decode("utf-8")
                    if self.filename is not None:
                        self.load_scenario(self.filename)
                if self.filename is None:  # only move on when the scenario is ready
                    return
                self.current_time = session_status.ScenarioTime
                self.communication_channel.send(self.current_time,
                                                self.heatmap, self.max_lat,
                                                self.max_long, self.min_lat,
                                                self.min_long)
                self.compute_and_send_estimate_hazardZone()
            if isinstance(lmcpObject, AirVehicleState):
                # Vehicle states are currently ignored (a Visdom scatter of the
                # drone positions used to be drawn here).
                pass
            if isinstance(lmcpObject, HazardZoneDetection):
                print(
                    f'time: {self.current_time} - hazardzone detection'.ljust(
                        100),
                    end='\r',
                    flush=True)
                hazardDetected: HazardZoneDetection = lmcpObject
                # Get location where zone first detected
                new_point = False
                detectedLocation = hazardDetected.get_DetectedLocation()
                lat, long = self.normalise_coordinates(detectedLocation)
                if not self._in_grid(lat, long):
                    print(f'detection at cell ({lat}, {long}) is outside the heatmap, skipped')
                elif self.heatmap[lat][long] != 1.0:
                    self.heatmap[lat][long] = 1.0
                    self.last_detected[lat][
                        long] = self.current_time  # the last registered time
                    self.apply_smoothing()
                    self.new_points_detected = True
                    new_point = True
                if new_point:
                    self.update_visdom()

        except Exception as ex:
            print(ex)

    def update_visdom(self):
        """Redraw the raw heatmap and the smoothed contour plot in Visdom."""
        self.viz.heatmap(X=self.heatmap,
                         win=HEATMAP,
                         opts=dict(title='Heatmap plot'))
        self.viz.contour(X=self.smooth,
                         win=CONTOUR,
                         opts=dict(title='Contour plot'))

    def apply_smoothing(self):
        """Recompute ``self.smooth`` as a Gaussian blur (sigma=10) of the heatmap."""
        self.smooth = ndimage.gaussian_filter(self.heatmap, 10)

    def normalise_coordinates(self, detectedLocation):
        """Map a location to ``(row, col)`` heatmap indices.

        Points outside the map bounds give indices outside the grid, which may
        be negative; callers must check them (see ``_in_grid``).
        """
        lat = int((detectedLocation.get_Latitude() - self.min_lat) /
                  (self.max_lat - self.min_lat) * self.heatmap.shape[0])
        long = int((detectedLocation.get_Longitude() - self.min_long) /
                   (self.max_long - self.min_long) * self.heatmap.shape[1])
        return lat, long

    def denormalise_coordinates(self, lat, long):
        """Inverse of ``normalise_coordinates``: grid indices -> (latitude, longitude)."""
        norm_lat = (lat * (self.max_lat - self.min_lat) /
                    self.heatmap.shape[0]) + self.min_lat
        norm_long = (long * (self.max_long - self.min_long) /
                     self.heatmap.shape[1]) + self.min_long
        return norm_lat, norm_long

    def update_heatmap(self, deltaTime):
        """
        Updates the heatmap and returns a new heatmap

        NOTE(review): not implemented; it returns None and is not called
        anywhere in this class.
        """

    def compute_coords(self):
        """Return the ``(row, col)`` heatmap cells where there is fire.

        A cell counts as burning when its value exceeds 0.95. This is a
        threshold rather than an ``== 1`` check so that a future decay of old
        detections keeps working. Cells are listed in row-major order as
        tuples of Python ints.
        """
        return [(int(row), int(col))
                for row, col in np.argwhere(self.heatmap > 0.95)]

    def set_coord_as_hazard_zone(self, norm_poly):
        """Replace the estimated zone's boundary with ``norm_poly``'s points.

        The points are converted from grid indices to latitude/longitude.
        """
        self.__estimatedHazardZone.get_BoundaryPoints().clear()
        for point in norm_poly.points:
            denormalised_point = Location3D()
            lat, long = self.denormalise_coordinates(point[0], point[1])
            denormalised_point.set_Latitude(lat)
            denormalised_point.set_Longitude(long)
            self.__estimatedHazardZone.get_BoundaryPoints().append(
                denormalised_point)

    def compute_and_send_estimate_hazardZone(self, force=False):
        """Recompute hazard polygons and send one estimate report per polygon.

        Recomputation happens when any of these holds:

        * ``force`` is set;
        * the scenario time has passed 10 s before the next entry of
          ``force_recompute_times`` (that entry is then consumed). The
          ``* 1000`` suggests ``current_time`` is in ms and the scoring times
          are in s.
        * new points were detected and more than ``REFRESH_RATE`` elapsed
          since the last refresh.

        Nothing is sent while fewer than 3 cells are burning.
        """
        delta_time = self.current_time - self.last_refresh
        if (self.force_recompute_times and self.current_time >
                (self.force_recompute_times[0] - 10) * 1000):  # 10 seconds before each scoring
            self.force_recompute_times.pop(0)  # remove first time
            force = True  # force recomputing
        if not (force or (delta_time > REFRESH_RATE and self.new_points_detected)):
            return
        self.last_refresh = self.current_time
        self.new_points_detected = False
        coords = self.compute_coords()

        # a polygon needs at least three points
        if len(coords) < 3:
            return

        norm_polys = calculate_polygons.calculate_polygons(coords)
        for index, poly in enumerate(norm_polys):
            self.set_coord_as_hazard_zone(poly)
            self.sendEstimateReport(index)

    def sendEstimateReport(self, id=1):
        """Send the current estimated zone as a fire ``HazardZoneEstimateReport``.

        Args:
            id: unique tracking id of the reported zone.
        """
        # Setting up the report to send to AMASE
        hazardZoneEstimateReport = HazardZoneEstimateReport()
        hazardZoneEstimateReport.set_EstimatedZoneShape(
            self.__estimatedHazardZone)
        hazardZoneEstimateReport.set_UniqueTrackingID(id)
        hazardZoneEstimateReport.set_EstimatedGrowthRate(0)
        hazardZoneEstimateReport.set_PerceivedZoneType(HazardType.Fire)
        hazardZoneEstimateReport.set_EstimatedZoneDirection(0)
        hazardZoneEstimateReport.set_EstimatedZoneSpeed(0)

        # Sending the report to AMASE to be interpreted
        self.__client.sendLMCPObject(hazardZoneEstimateReport)
コード例 #23
0
ファイル: demo.py プロジェクト: Garvit244/visdom
# grouped bar chart: 20 groups of 3 side-by-side bars
bar_data = np.random.rand(20, 3)
viz.bar(X=bar_data,
        opts={
            'stacked': False,
            'legend': ['The Netherlands', 'France', 'United States'],
        })

# histogram of 10k uniform samples in 20 bins
hist_samples = np.random.rand(10000)
viz.histogram(X=hist_samples, opts={'numbins': 20})

# heatmap of the 5x10 outer product 1..5 x 1..10
heat_data = np.outer(np.arange(1, 6), np.arange(1, 11))
viz.heatmap(X=heat_data,
            opts={
                'columnnames': list('abcdefghij'),
                'rownames': ['y%d' % k for k in range(1, 6)],
                'colormap': 'Electric',
            })

# contour of a Gaussian bump centred at (50, 50)
x = np.tile(np.arange(1, 101), (100, 1))
y = x.transpose()
squared_dist = ((x - 50) ** 2) + ((y - 50) ** 2)
X = np.exp(squared_dist / -(20.0 ** 2))
viz.contour(X=X, opts={'colormap': 'Viridis'})

# surface of the same data
viz.surf(X=X, opts={'colormap': 'Hot'})

# line plot of 10 random values
line_data = np.random.rand(10)
viz.line(Y=line_data)
コード例 #24
0
ファイル: main.py プロジェクト: zhaoforever/SANAS
def main(_run, nepochs, device, use_visdom, visdom_conf, n_classes,
         lambda_reward, r_beta, r_gamma, _config):
    """Train the adaptive model, evaluate it every epoch and log the results.

    Each epoch trains on ``ds['train']`` and then evaluates on
    ``ds['validation']`` and ``ds['test']``. Metrics go to ``xp_logger``,
    whose hook forwards them to ``_run.log_scalar``. When ``use_visdom`` is
    set, they are also plotted to Visdom.

    A checkpoint is saved when the current epoch appears in the front returned
    by ``paretize_exp`` over validation (cost, accuracy). Otherwise a
    checkpoint is saved every 50 epochs.
    """
    exp_name = format_exp_name(_run._id, _config)
    if use_visdom:
        visdom_conf.update(env=exp_name)
        _run.info['visdom_server'] = "{server}:{port}/env/{env}".format(
            **visdom_conf)
    else:
        _run.info['visdom_server'] = "No visdom"

    _run.info['exp_name'] = exp_name
    front = _run.info['front'] = {}

    xp_logger = data_logger.Experiment(exp_name,
                                       use_visdom=use_visdom,
                                       visdom_opts=visdom_conf,
                                       time_indexing=False,
                                       xlabel='Epoch',
                                       log_git_hash=False)
    xp_logger.add_log_hook(_run.log_scalar)
    if use_visdom:
        xp_logger.plotter.windows_opts = defaultdict(
            lambda: dict(showlegend=True))

    viz = Visdom(**visdom_conf) if use_visdom else None

    # Dataset creation
    logger.info('### Dataset ###')

    ds, batch_first, class_w = create_dataset()
    _run.info['class_weights'] = class_w.tolist()

    confusion_matrix_opts = {
        'columnnames': ds['train'].dataset.ordered_class_names,
        'rownames': ds['train'].dataset.ordered_class_names
    }

    # Model Creation
    logger.info('### Model ###')

    adaptive_model = create_model()
    adaptive_model.loss = torch.nn.CrossEntropyLoss(weight=class_w,
                                                    reduction='none',
                                                    ignore_index=-7)

    path_recorder = PathRecorder(adaptive_model.stochastic_model)
    cost_evaluator = ComputationCostEvaluator(
        node_index=path_recorder.node_index, bw=False)

    cost_evaluator.init_costs(adaptive_model.stochastic_model)
    logger.info('Cost: {:.5E}'.format(cost_evaluator.total_cost))

    adaptive_model.to(device)

    # Optim Creation
    logger.info('### Optim ###')
    optimizer, scheduler = create_optim(
        params=adaptive_model.get_param_groups())

    # Check the param_groups order, to be sure to get the learning rates in
    # the right order for logging (scheduler.get_lr() is unpacked below).
    assert [pg['name'] for pg in optimizer.param_groups
            ] == ['arch_params', 'pred_params']

    def optim_closure(loss):
        """Single optimisation step on ``loss``; passed to the training pass."""
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Logger creation
    splits = ['train', 'validation', 'test']
    metrics = [
        'classif_loss', 'arch_loss', 'reward', 'lambda_reward',
        'silence_ratio', 'accuracy', 'average_cost', 'learning_rate_pred',
        'learning_rate_arch'
    ]

    # Each wrapper is later reachable as xp_logger.Parent_<Split>.
    for split in splits:
        xp_logger.ParentWrapper(tag=split,
                                name='parent',
                                children=[
                                    xp_logger.SimpleMetric(name=metric)
                                    for metric in metrics
                                ])

    def make_cost_loggers(split_key, metric_name):
        """One AvgMetric per class of ``ds[split_key]``, keyed by class index."""
        class_names = ds[split_key].dataset.ordered_class_names
        return {
            i: xp_logger.AvgMetric(name=metric_name, tag=name)
            for i, name in enumerate(class_names)
        }

    train_cost_loggers = make_cost_loggers('train', 'train_cost')
    train_cost_loggers_perc = make_cost_loggers('train',
                                                'train_cost_perceived')

    node_names = adaptive_model.stochastic_model.ordered_node_names
    entropy_loggers = OrderedDict(
        (i, xp_logger.SimpleMetric(name='entropy_per_node', tag=name))
        for i, name in enumerate(node_names))
    proba_loggers = OrderedDict(
        (i, xp_logger.SimpleMetric(name='proba_per_node', tag=name))
        for i, name in enumerate(node_names))

    val_cost_loggers = make_cost_loggers('validation', 'val_cost')
    val_cost_loggers_perc = make_cost_loggers('validation',
                                              'val_cost_perceived')

    test_cost_loggers = make_cost_loggers('test', 'test_cost')
    test_cost_loggers_perc = make_cost_loggers('test', 'test_cost_perceived')

    if use_visdom:
        print_properties(viz, _config)
        print_properties(viz, _run.info)

    ema_reward = EMA(r_beta)  # Init the exponential moving average
    for n in range(1, nepochs + 1):
        logger.info('### Starting epoch n°{} ### {}'.format(
            n, _run.info['visdom_server']))
        logger.info(' '.join(sys.argv))

        if scheduler:
            scheduler.step(n)
            # Order matches optimizer.param_groups (asserted above).
            arch_lr, pred_lr = scheduler.get_lr()
            xp_logger.Parent_Train.update(learning_rate_pred=pred_lr,
                                          learning_rate_arch=arch_lr)

        # Training
        adaptive_model.train()
        train_cm, train_costcm, train_costcm_norm, train_cost_per_step, logs, train_cost_per_signal_level, train_stats = evaluate_model(
            adaptive_model,
            ds['train'],
            batch_first,
            device,
            path_recorder,
            cost_evaluator,
            train_cost_loggers,
            train_cost_loggers_perc,
            n_classes,
            lambda_reward,
            ema_reward,
            r_gamma,
            optim_closure,
            name='Train')

        xp_logger.Parent_Train.update(
            **{k: v.value()[0] for k, v in logs.items()})

        for node_idx, ent in train_stats['en'].items():
            entropy_loggers[node_idx].update(ent.value()[0])

        for node_idx, prob in train_stats['pn'].items():
            proba_loggers[node_idx].update(prob.value()[0])

        # Evaluation (no optim_closure, so no parameter updates)
        adaptive_model.eval()
        val_cm, val_costcm, val_costcm_norm, val_cost_per_step, logs, val_cost_per_signal_level, val_stats = evaluate_model(
            adaptive_model,
            ds['validation'],
            batch_first,
            device,
            path_recorder,
            cost_evaluator,
            val_cost_loggers,
            val_cost_loggers_perc,
            n_classes,
            lambda_reward,
            ema_reward,
            r_gamma,
            name='Validation')

        xp_logger.Parent_Validation.update(
            **{k: v.value()[0] for k, v in logs.items()})

        test_cm, test_costcm, test_costcm_norm, test_cost_per_step, logs, test_cost_per_signal_level, test_stats = evaluate_model(
            adaptive_model,
            ds['test'],
            batch_first,
            device,
            path_recorder,
            cost_evaluator,
            test_cost_loggers,
            test_cost_loggers_perc,
            n_classes,
            lambda_reward,
            ema_reward,
            r_gamma,
            name='Test')
        xp_logger.Parent_Test.update(
            **{k: v.value()[0] for k, v in logs.items()})

        if use_visdom:
            # Log
            plot_(viz,
                  train_stats['es'],
                  node_names,
                  f'Entropy per step {n} - Train',
                  'train_eps',
                  log_func=_run.log_scalar)
            plot_(viz,
                  train_stats['ps'],
                  node_names,
                  f'Probability per step {n} - Train',
                  'train_pps',
                  log_func=_run.log_scalar)

            # (matrix, window id, title) for every heatmap, in plotting order.
            heatmaps = [
                (train_cm, 'train_cm', 'Train Confusion matrix'),
                (val_cm, 'val_cm', 'Val Confusion matrix'),
                (test_cm, 'test_cm', 'Test Confusion matrix'),
                (train_costcm_norm, 'train_cost_matrix_norm',
                 'Train cost matrix Normalized'),
                (val_costcm_norm, 'val_cost_matrix_norm',
                 'Val cost matrix Normalized'),
                (test_costcm_norm, 'test_cost_matrix_norm',
                 'Test cost matrix Normalized'),
            ]
            try:
                for matrix, win, title in heatmaps:
                    viz.heatmap(matrix,
                                win=win,
                                opts={
                                    **confusion_matrix_opts, 'title': title
                                })
            except ConnectionError as err:
                # Plotting is best-effort: report and keep training.
                logger.warning('Error in heatmaps:')
                logger.warning(err)
                traceback.print_exc()

            plot_meters(viz,
                        train_cost_per_step,
                        'train_cps',
                        'Cost per step {}'.format(n),
                        win='cps',
                        log_func=_run.log_scalar)
            plot_meters(viz,
                        val_cost_per_step,
                        'val_cps',
                        win='cps',
                        log_func=_run.log_scalar)
            plot_meters(viz,
                        test_cost_per_step,
                        'test_cps',
                        win='cps',
                        log_func=_run.log_scalar)

            plot_meters(viz,
                        train_cost_per_signal_level,
                        'cost/sig_train',
                        'Cost per signal {}'.format(n),
                        win='cpsig',
                        error_bars=False,
                        log_func=_run.log_scalar)
            plot_meters(viz,
                        val_cost_per_signal_level,
                        'cost/sig_val',
                        win='cpsig',
                        error_bars=False,
                        log_func=_run.log_scalar)
            plot_meters(viz,
                        test_cost_per_signal_level,
                        'cost/sig_test',
                        win='cpsig',
                        error_bars=False,
                        log_func=_run.log_scalar)

        xp_logger.log_with_tag(tag='*', reset=True)

        msg = 'Losses: {:.3f}({:.3E})-{:.3f}-{:.3f}, Accuracies: {:.3f}-{:.3f}-{:.3f}, Avg cost: {:.3E}-{:.3E}-{:.3E}'
        msg = msg.format(xp_logger.classif_loss_train, xp_logger.reward_train,
                         xp_logger.classif_loss_validation,
                         xp_logger.classif_loss_test, xp_logger.accuracy_train,
                         xp_logger.accuracy_validation,
                         xp_logger.accuracy_test, xp_logger.average_cost_train,
                         xp_logger.average_cost_validation,
                         xp_logger.average_cost_test)
        logger.info(msg)

        pareto_data = {
            'cost': xp_logger.logged['average_cost_validation'].values(),
            'acc': xp_logger.logged['accuracy_validation'].values(),
            '_orig_': xp_logger.logged['average_cost_validation'].keys()
        }

        pareto = paretize_exp(pareto_data, x_name='cost', crit_name='acc')

        if n in pareto['_orig_']:
            logger.info('New on front !')
            front.update(**pareto)
            save_checkpoint(adaptive_model, ex, n)
        elif n % 50 == 0:  # n starts at 1, so no need to guard n > 0
            logger.info('Checkpointing')
            save_checkpoint(adaptive_model, ex, n)

        logger.info(pareto['_orig_'])
        # NOTE(review): assumes paretize_exp returns '_orig_' ordered so that
        # the last entry is the best epoch — confirm against its implementation.
        best_epoch = pareto['_orig_'][-1]
        logger.info('Best \tVal: {:.3f} - Test: {:.3f} (Epoch {})\n'.format(
            xp_logger.logged['accuracy_validation'][best_epoch],
            xp_logger.logged['accuracy_test'][best_epoch], best_epoch))
コード例 #25
0
ファイル: utils.py プロジェクト: MkuuWaUjinga/DeepMDP-SSL4RL
class Visualizer:
    """Send training diagnostics of an RL agent (``algo``) to a Visdom server.

    The plots produced are selected by ``plot_list``. The recognised entries
    are ``"weight_plot"``, ``"aux_loss_plot"``,
    ``"latent_space_correlation_plot"`` and ``"episodical_stats"``.
    """

    # TODO plot episode lengths
    # TODO plot distribution over chosen actions.

    def __init__(self, experiment_id, plot_list, port=9098):
        """
        :param experiment_id: Visdom environment name used for the plots.
        :param plot_list: container of enabled plot names (see class docs).
        :param port: port of the Visdom server to connect to.
        """
        # Record the port actually used (it used to be hard-coded to 9098
        # regardless of the ``port`` argument).
        self.port = port
        self.plot_list = plot_list
        self.viz = Visdom(port=port)
        self.env = experiment_id
        self.line_plotter = VisdomLinePlotter(self.viz, env_name=experiment_id)
        # Visdom window handle of the correlation heatmap, reused on updates.
        self.correlation_plot_window = None
        # Auxiliary-loss values collected since the last plot, per loss type.
        self.aux_losses = defaultdict(list)
        # Running sum of sampled correlation matrices (None until first sample).
        self.correlation_matrix = None
        self.num_calls = 0
        # save_latent_space only samples on every store_every_th-th call.
        self.store_every_th = 10
        self.count_correlation_matrix = 0  # Can be calculated from num_calls and store_every_th

    def publish_config(self, config):
        """Show ``config`` as a pretty-printed HTML text pane in Visdom."""
        config_string = pprint.pformat(dict(config)).replace(
            "\n", "<br>").replace(" ", "&nbsp;")
        self.viz.text(config_string, env=self.env)

    def visualize_episodical_stats(self, algo, num_new_episodes):
        """Update every enabled plot for the ``num_new_episodes`` latest episodes."""
        if self.make_weights_plot():
            self.visualize_weights(algo, num_new_episodes)
        if self.visualize_aux():
            self.visualize_aux_losses(num_new_episodes,
                                      len(algo.episode_rewards))
        if self.visualize_latent_space():
            self.visualize_latent_space_correlation(num_new_episodes,
                                                    len(algo.episode_rewards),
                                                    algo.experiment_id)
        if self.visualize_stats():
            num_episodes = len(algo.episode_rewards)
            for i in range(num_episodes - num_new_episodes, num_episodes):
                self.line_plotter.plot("episode reward", "rewards",
                                       "Rewards per episode", i,
                                       algo.episode_rewards[i])
                self.line_plotter.plot("episode mean q-values", "q-values",
                                       "Mean q-values per episode", i,
                                       algo.episode_mean_q_vals[i])
                self.line_plotter.plot("episode std q-values", "q-std",
                                       "Std of q-values per episode", i,
                                       algo.episode_std_q_vals[i])
                # Running average of the 100 episodes *preceding* episode i
                # (episode i itself is not part of the window).
                if i > 100:
                    self.line_plotter.plot(
                        "episode reward",
                        "avg reward",
                        "Rewards per episode",
                        i,
                        np.mean(algo.episode_rewards[i - 100:i]),
                        color=np.array([
                            [0, 0, 128],
                        ]))

    def visualize_module(self, head, head_name, num_episodes,
                         num_new_episodes):
        """Plot L2-norm, min, max and mean of each parameter tensor of ``head``.

        The statistics are taken once from the current weights and repeated
        for each of the ``num_new_episodes`` latest episode indices, so the
        x-axis stays aligned with the per-episode plots.
        """
        first_episode = num_episodes - num_new_episodes
        for x, params in enumerate(head.parameters()):
            weights = params.data
            shape = list(params.shape)
            l2_norm = weights.norm(p=2).cpu().numpy()
            w_min = torch.min(weights).cpu().numpy()
            w_max = torch.max(weights).cpu().numpy()
            w_mean = torch.mean(weights).cpu().numpy()
            var_name = f"metrics {head_name} {x}"
            for i in range(num_new_episodes):
                episode = first_episode + i
                self.line_plotter.plot(var_name, "L2-norm",
                                       f"Weights {head_name} {shape}",
                                       episode, l2_norm)
                self.line_plotter.plot(var_name,
                                       "min",
                                       f"Weights of {head_name} {shape}",
                                       episode,
                                       w_min,
                                       color=np.array([
                                           [0, 0, 128],
                                       ]))
                self.line_plotter.plot(var_name,
                                       "max",
                                       f"Weights of {head_name} {shape}",
                                       episode,
                                       w_max,
                                       color=np.array([
                                           [128, 0, 0],
                                       ]))
                self.line_plotter.plot(var_name,
                                       "mean",
                                       f"Weights of {head_name} {shape}",
                                       episode,
                                       w_mean,
                                       color=np.array([
                                           [0, 128, 0],
                                       ]))

    def visualize_weights(self, algo, num_new_episodes):
        """Plot weight statistics of the Q-head, encoder and auxiliary nets.

        The plots go to the ``<experiment_id>_weights`` environment. Afterwards
        the line plotter is switched to ``<experiment_id>_main``.
        """
        num_episodes = len(algo.episode_rewards)
        self.line_plotter.env = algo.experiment_id + "_weights"
        self.visualize_module(algo.qf.head, "Q-head", num_episodes,
                              num_new_episodes)
        self.visualize_module(algo.qf.encoder, "Encoder", num_episodes,
                              num_new_episodes)
        for aux in algo.auxiliary_objectives:
            self.visualize_module(aux.net, aux.__class__.__name__,
                                  num_episodes, num_new_episodes)
        self.line_plotter.env = algo.experiment_id + "_main"

    def make_weights_plot(self):
        """Whether weight statistics plots are enabled."""
        return "weight_plot" in self.plot_list

    def visualize_aux(self):
        """Whether auxiliary-loss plots are enabled."""
        return "aux_loss_plot" in self.plot_list

    def visualize_latent_space(self):
        """Whether latent-space correlation plots are enabled."""
        return "latent_space_correlation_plot" in self.plot_list

    def visualize_stats(self):
        """Whether per-episode reward / q-value plots are enabled."""
        return "episodical_stats" in self.plot_list

    def save_aux_loss(self, loss, loss_type):
        """Buffer one auxiliary-loss value under ``loss_type`` (if enabled)."""
        if self.visualize_aux():
            self.aux_losses[loss_type].append(loss)

    def visualize_aux_losses(self, num_new_episodes, total_num_episode):
        """Plot mean and median of the buffered aux losses, then clear the buffer.

        The same mean/median is plotted for each of the ``num_new_episodes``
        latest episode indices. Nothing happens (and the buffer is kept) when
        no losses were buffered or no new episode finished.
        """
        if self.aux_losses and num_new_episodes > 0:
            first_episode = total_num_episode - num_new_episodes
            for aux_loss, values in self.aux_losses.items():
                # Loop-invariant: compute the statistics once per loss type.
                loss_mean = np.mean(values)
                loss_median = np.median(values)
                for i in range(num_new_episodes):
                    self.line_plotter.plot(aux_loss, "mean", aux_loss,
                                           first_episode + i, loss_mean)
                    self.line_plotter.plot(aux_loss,
                                           "median",
                                           aux_loss,
                                           first_episode + i,
                                           loss_median,
                                           color=np.array([
                                               [0, 0, 128],
                                           ]))
            self.aux_losses = defaultdict(list)

    def save_latent_space(self, algo, next_obs, ground_truth_embedding):
        """Accumulate the correlation between learned and ground-truth embeddings.

        Only every ``store_every_th``-th call (counted by ``num_calls``)
        samples, and only when latent-space plots are enabled. ``algo.qf`` is
        run in eval mode without gradients and then switched back to train
        mode. ``device`` is a name from the enclosing module, not visible here.

        :raises ValueError: if a sample is due and ``ground_truth_embedding``
            is None.
        """
        if self.visualize_latent_space(
        ) and self.num_calls % self.store_every_th == 0:
            if ground_truth_embedding is None:
                raise ValueError(
                    "Ground truth embedding mustn't be of None type")
            ground_truth_embedding = ground_truth_embedding.to(device)
            algo.qf.eval()
            with torch.no_grad():
                _, embedding = algo.qf(next_obs, return_embedding=True)
            algo.qf.train()
            assert embedding.size() == ground_truth_embedding.size()
            if self.correlation_matrix is None:
                embedding_dim = embedding.size(1)
                self.correlation_matrix = torch.zeros(
                    (embedding_dim, embedding_dim)).to(device)
            # Transpose so rows are embedding dimensions, columns are samples.
            self.correlation_matrix += self.calculate_correlation(
                embedding.t(), ground_truth_embedding.t())
            self.count_correlation_matrix += 1
        self.num_calls += 1

    def visualize_latent_space_correlation(self, num_new_episodes,
                                           total_num_episodes, experiment_id):
        """Plot the averaged correlation matrix as a heatmap and per-entry lines.

        Afterwards the accumulated matrix and its sample count are reset.
        Row/column labels are hard-coded for an 8-dimensional embedding.
        """
        if self.correlation_matrix is not None and num_new_episodes > 0:
            self.correlation_matrix = self.correlation_matrix.div(
                self.count_correlation_matrix)
            assert round(torch.max(
                self.correlation_matrix).item(), 2) <= 1.0 and round(
                    torch.min(self.correlation_matrix).item(),
                    2) >= -1.0, "Invalid value for correlation coefficient!"
            self.line_plotter.env = experiment_id + "_latent_space"
            column_names = [
                "pos_x", "pos_y", "vel_x", "vel_y", "ang", "ang_vel", "leg_1",
                "leg_2"
            ]
            row_names = ['l1', 'l2', 'l3', 'l4', 'l5', 'l6', 'l7', 'l8']
            self.correlation_plot_window = self.viz.heatmap(
                X=self.correlation_matrix,
                env=self.env,
                win=self.correlation_plot_window,
                opts=dict(
                    columnnames=column_names,
                    rownames=row_names,
                    colormap='Viridis',
                    xmin=-1.0,
                    xmax=1.0,
                    title=
                    "Average latent space correlation per batch and episode"))
            first_episode = total_num_episodes - num_new_episodes
            for i, column_name in enumerate(column_names):
                for j, row_name in enumerate(row_names):
                    for k in range(num_new_episodes):
                        # Colour shifts from green to red as j (latent dim) grows.
                        self.line_plotter.plot(
                            column_name + "_correlation",
                            row_name,
                            column_name,
                            first_episode + k,
                            self.correlation_matrix[j, i].cpu().numpy(),
                            color=np.array([
                                [
                                    int((255 / 8) * j),
                                    int((255 / 8) * (8 - j)), 0
                                ],
                            ]))
            self.line_plotter.env = experiment_id + "_main"
            self.correlation_matrix = None
            self.count_correlation_matrix = 0

    @staticmethod
    def calculate_correlation(x1, x2):
        """Pearson correlation between every row of ``x1`` and every row of ``x2``.

        :param x1: 2D tensor of shape (latent_space_size, sample_size).
        :param x2: 2D tensor of shape (latent_space_size, sample_size).
        :return: 2D tensor ``c`` of shape (latent_space_size, latent_space_size)
            where ``c[a, b]`` is the correlation of row ``a`` of ``x1`` with
            row ``b`` of ``x2`` across samples.
        """
        with torch.no_grad():
            # Sample covariance of rows (divided by n - 1).
            mean_x1 = torch.mean(x1, 1).unsqueeze(1)
            mean_x2 = torch.mean(x2, 1).unsqueeze(1)
            x1m = x1.sub(mean_x1)
            x2m = x2.sub(mean_x2)
            c = x1m.mm(x2m.t())
            c = c / (x1.size(1) - 1)
            # Normalize by standard deviations (torch.std is also n - 1 based).
            # Add epsilon for numerical stability if std close to 0.
            epsilon = 1e-9
            std_x1 = torch.std(x1, 1).unsqueeze(1) + epsilon
            std_x2 = torch.std(x2, 1).unsqueeze(1) + epsilon
            c = c.div(std_x1)
            c = c.div(std_x2.t())

            assert round(torch.max(c).item(), 2) <= 1.0 and round(
                torch.min(c).item(),
                2) >= -1.0, "Invalid value for correlation coefficient!"
            return c
コード例 #26
0
def run(args):
    """Train the model selected by ``args.model`` and plot the results to Visdom.

    ``args.model`` is matched case-insensitively: ``'vae'`` -> VAE,
    ``'rae'`` -> CRAE, ``'unet'`` -> SUnet, ``'avb'`` -> AVB, anything else ->
    SDAE. After ``args.epochs`` epochs, the first test batch is reconstructed.
    The masked inputs, the reconstructions and the loss curves are then sent
    to the Visdom environment named ``args.model``.

    :return: ``(last_train_loss, last_test_loss)``.
    """
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # Load data
    train_data, test_data, train_mask, test_mask, _user_list = load_data(random_split=True)
    # train_data, test_data, train_mask, test_mask, user_list = load_toy_data()

    # Params
    _n_samples, n_bins, n_mods = train_data.shape
    n_features = n_bins * n_mods
    modalities = ['cpm', 'steps', 'screen', 'location_lat', 'location_lon'][:n_mods]
    # Trailing partial batches are dropped (floor division).
    num_train = train_data.shape[0] // args.batch_size
    num_test = test_data.shape[0] // args.batch_size

    # Convert to torch tensor
    train_data = torch.from_numpy(train_data)
    test_data = torch.from_numpy(test_data)
    train_mask = torch.from_numpy(train_mask).float()
    test_mask = torch.from_numpy(test_mask).float()

    def get_batch(source, mask, i, evaluation=False):
        """Return batch ``i`` of ``source`` and ``mask`` wrapped in Variables."""
        start, stop = i * args.batch_size, (i + 1) * args.batch_size
        data = Variable(source[start:stop], volatile=evaluation)
        _mask = Variable(mask[start:stop], volatile=evaluation)
        return data, _mask

    # Normalise once so the constructor choice and the output-arity check
    # below always agree (previously 'VAE' built a VAE but was unpacked as a
    # 2-output model, and e.g. 'cvae' built an SDAE but was unpacked as 4).
    model_name = args.model.lower()
    if model_name == 'vae':
        model = VAE(args.layers, input_dim=n_features, args=args)
    elif model_name == 'rae':
        model = CRAE(args.layers, input_dim=n_features, args=args)
    elif model_name == 'unet':
        model = SUnet(args.layers, input_dim=n_features, args=args)
    elif model_name == 'avb':
        model = AVB(args.layers, input_dim=n_features, args=args)
    else:
        model = SDAE(args.layers, input_dim=n_features, args=args)
    print(model)

    def train(epoch):
        """One training epoch; returns the summed loss divided by len(train_data)."""
        model.train()
        epoch_loss = 0
        for batch_idx in range(num_train):
            data, mask = get_batch(train_data, train_mask, batch_idx, evaluation=False)

            if args.cuda:
                data = data.cuda()

            # Run model updates and collect loss (model.forward does the step).
            # NOTE(review): if this returns a graph-carrying tensor, summing it
            # keeps every batch's graph alive — confirm it is detached/scalar.
            loss = model.forward(data, mask)
            epoch_loss += loss

            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_data),
                    100. * batch_idx / num_train,
                    loss / len(data)))

        print('====> Epoch: {} Average loss: {:.6f}'.format(
            epoch, epoch_loss / len(train_data)))
        return epoch_loss / len(train_data)

    def test(epoch):
        """Evaluate on the test set; returns the summed loss / len(test_data)."""
        model.eval()
        epoch_loss = 0
        for batch_idx in range(num_test):
            data, mask = get_batch(test_data, test_mask, batch_idx, evaluation=True)
            if args.cuda:
                data = data.cuda()

            # Evaluate batch on model
            epoch_loss += model.eval_loss(data, mask)

        epoch_loss /= len(test_data)
        print('====> Test set loss: {:.6f}'.format(epoch_loss))
        return epoch_loss

    train_loss = []
    test_loss = []
    for epoch in range(1, args.epochs + 1):
        train_loss.append(train(epoch))
        test_loss.append(test(epoch))

    # Plot result
    test_batch, test_mask_batch = get_batch(test_data, test_mask, 0, evaluation=True)

    if model_name == 'vae':
        # The VAE additionally returns mu and logvar, which are not plotted.
        recon_batch, _mu, _logvar, noise = model(test_batch, test_mask_batch)
    else:
        recon_batch, noise = model(test_batch, test_mask_batch)

    # Mask out known values
    test_batch = test_batch * test_mask_batch
    recon_batch = recon_batch * test_mask_batch  # * (1 - noise)

    test_batch = test_batch.data.numpy().reshape(-1, n_bins, n_mods)
    recon_batch = recon_batch.data.numpy().reshape(-1, n_bins, n_mods)

    # Create a visdom object
    vis = Visdom(env=args.model)

    # Heatmap: true vs reconstructed values per modality, on a shared scale.
    for i, mod in enumerate(modalities):
        vmax = np.max((test_batch[:, :, i].max(), recon_batch[:, :, i].max()))
        vis.heatmap(test_batch[:, :, i],
                    opts=dict(colormap='Electric', title='true_' + mod, xmin=0, xmax=float(vmax)))
        vis.heatmap(recon_batch[:, :, i],
                    opts=dict(colormap='Electric', title='recon_' + mod, xmin=0, xmax=float(vmax)))
    vis.heatmap(((1 - noise) * test_mask_batch)[:, :, 0].data.numpy(), opts=dict(title='mask'))

    # Errors (the first epoch is skipped to keep the y-scale readable).
    vis.line(np.stack((train_loss[1:], test_loss[1:]), axis=1),
             np.tile(np.arange(args.epochs - 1), (2, 1)).transpose(),
             opts=dict(legend=['train', 'test']))

    return train_loss[-1], test_loss[-1]
コード例 #27
0
ファイル: callbacks.py プロジェクト: RaRe-Technologies/gensim
class Callback(object):
    """A class representing routines called reactively at specific phases during training.

    These can be used to log or visualize the training progress using any of the metric scores developed before.
    The values are stored at the end of each training epoch. The following metric scores are currently available:

        * :class:`~gensim.models.callbacks.CoherenceMetric`
        * :class:`~gensim.models.callbacks.PerplexityMetric`
        * :class:`~gensim.models.callbacks.DiffMetric`
        * :class:`~gensim.models.callbacks.ConvergenceMetric`

    """
    def __init__(self, metrics):
        """

        Parameters
        ----------
        metrics : list of :class:`~gensim.models.callbacks.Metric`
            The list of metrics to be reported by the callback.

        """
        self.metrics = metrics

    def set_model(self, model):
        """Save the model instance and initialize any required variables which would be updated throughout training.

        Parameters
        ----------
        model : :class:`~gensim.models.basemodel.BaseTopicModel`
            The model for which the training will be reported (logged or visualized) by the callback.

        Raises
        ------
        ImportError
            If a metric uses the "visdom" logger but Visdom is not installed.

        """
        self.model = model
        self.previous = None
        # check for any metric which need model state from previous epoch
        if any(isinstance(metric, (DiffMetric, ConvergenceMetric)) for metric in self.metrics):
            self.previous = copy.deepcopy(model)
            # store diff diagonals of previous epochs
            self.diff_mat = Queue()
        if any(metric.logger == "visdom" for metric in self.metrics):
            if not VISDOM_INSTALLED:
                raise ImportError("Please install Visdom for visualization")
            self.viz = Visdom()
            # Plot window of every visdom metric, keyed by the metric's index in ``self.metrics`` (the same window is
            # updated with increasing epochs). A dict rather than a list: only visdom metrics get a window, so list
            # positions would not line up with metric indices when other loggers are mixed in.
            self.windows = {}
        if any(metric.logger == "shell" for metric in self.metrics):
            # set logger for current topic model
            self.log_type = logging.getLogger('gensim.models.ldamodel')

    def on_epoch_end(self, epoch, topics=None):
        """Report the current epoch's metric value.

        Called at the end of each training iteration.

        Parameters
        ----------
        epoch : int
            The epoch that just ended.
        topics : list of list of str, optional
            List of tokenized topics. This is required for the coherence metric.

        Returns
        -------
        dict of (str, object)
            Mapping from metric names to their values. The type of each value depends on the metric type,
            for example :class:`~gensim.models.callbacks.DiffMetric` computes a matrix while
            :class:`~gensim.models.callbacks.ConvergenceMetric` computes a float.

        """
        # stores current epoch's metric values
        current_metrics = {}

        # plot all metrics in current epoch
        for i, metric in enumerate(self.metrics):
            label = str(metric)
            value = metric.get_value(topics=topics, model=self.model, other_model=self.previous)

            current_metrics[label] = value

            if metric.logger == "visdom":
                # np.ndim also accepts plain Python scalars, which have no ``.ndim`` attribute
                is_matrix = np.ndim(value) > 0
                if epoch == 0 or i not in self.windows:
                    # first report of this metric: create its window
                    # (the missing-window guard avoids a blocking ``Queue.get`` below when epoch 0 was never seen)
                    if is_matrix:
                        diff_mat = np.array([value])
                        viz_metric = self.viz.heatmap(
                            X=diff_mat.T, env=metric.viz_env, opts=dict(xlabel='Epochs', ylabel=label, title=label)
                        )
                        # store current epoch's diff diagonal
                        self.diff_mat.put(diff_mat)
                    else:
                        viz_metric = self.viz.line(
                            Y=np.array([value]), X=np.array([epoch]), env=metric.viz_env,
                            opts=dict(xlabel='Epochs', ylabel=label, title=label)
                        )
                    # saving initial plot window
                    self.windows[i] = copy.deepcopy(viz_metric)
                elif is_matrix:
                    # concatenate with previous epoch's diff diagonals and redraw in the same window
                    diff_mat = np.concatenate((self.diff_mat.get(), np.array([value])))
                    self.viz.heatmap(
                        X=diff_mat.T, env=metric.viz_env, win=self.windows[i],
                        opts=dict(xlabel='Epochs', ylabel=label, title=label)
                    )
                    self.diff_mat.put(diff_mat)
                else:
                    self.viz.line(
                        Y=np.array([value]),
                        X=np.array([epoch]),
                        env=metric.viz_env,
                        win=self.windows[i],
                        update='append'
                    )

            if metric.logger == "shell":
                statement = "".join(("Epoch ", str(epoch), ": ", label, " estimate: ", str(value)))
                self.log_type.info(statement)

        # check for any metric which need model state from previous epoch
        if any(isinstance(metric, (DiffMetric, ConvergenceMetric)) for metric in self.metrics):
            self.previous = copy.deepcopy(self.model)

        return current_metrics
コード例 #28
0
ファイル: test.py プロジェクト: tamirtrack/soccerontable
    if opt.input_nc > 3:
        input = torch.cat((input, mask), 1)

    if opt.cuda:
        input = input.cuda()
        target = target.cuda()

    output = netG(input)
    final_prediction = logsoftmax(output[-1])

    img, prediction, label, mask = convert_test_prediction(
        input, mask, target, final_prediction)

    viz.image(img.transpose(2, 0, 1) * mask[0, :, :, :],
              win=win0,
              opts=dict(title='Testing {0}: Input image'.format(opt.dataset)))
    viz.heatmap(
        prediction[::-1, :],
        win=win1,
        opts=dict(title='Testing {0}: Depth estimation'.format(opt.dataset)))
    viz.heatmap(label[::-1, :],
                win=win2,
                opts=dict(title='Testing {0}: Label'.format(opt.dataset)))

    # Save predictions
    fname = testing_data_loader.dataset.image_filenames[iteration]
    basename, ext = file_utils.extract_basename(fname)
    np.save(join(path_to_data, 'predictions', basename),
            final_prediction.cpu().data.numpy())
コード例 #29
0
ファイル: plotting.py プロジェクト: amagooda/siatl
class Visualizer:
    """Small convenience layer over a Visdom client.

    Every ``plot_*`` method passes its ``name``/``title`` argument as the Visdom
    window id, so calls that share a title target the same window.
    """

    def __init__(self,
                 env="main",
                 server="http://localhost",
                 port=8097,
                 base_url="/",
                 http_proxy_host=None,
                 http_proxy_port=None):
        """Connect to the Visdom server and clear environment ``env``.

        NOTE(review): ``base_url`` is accepted but never forwarded to ``Visdom``;
        confirm whether that is intentional.
        """
        client = Visdom(env=env,
                        server=server,
                        port=port,
                        http_proxy_host=http_proxy_host,
                        http_proxy_port=http_proxy_port,
                        use_incoming_socket=False)
        self._viz = client
        # ``close`` is called without a window id to wipe the environment
        client.close(env=env)

    def plot_line(self, values, steps, name, legend=None):
        """Append points to the line plot in window ``name``.

        ``steps`` and ``values`` are each turned into columns with
        ``numpy.column_stack``; ``legend`` is only sent when given.
        """
        line_opts = {"title": name}
        if legend is not None:
            line_opts["legend"] = legend

        xs = numpy.column_stack(steps)
        ys = numpy.column_stack(values)
        self._viz.line(X=xs, Y=ys, win=name, update='append', opts=line_opts)

    def plot_text(self, text, title, pre=True):
        """Show ``text`` in window ``title``, optionally wrapped in ``<pre>``.

        The window is sized from the text: 10px per character of the longest
        line and 20px per line (at least 120px), each capped at 400px.
        """
        rows = text.split("\n")
        width = 10 * max(len(row) for row in rows)
        height = max(20 * len(rows), 120)

        body = "<pre>{}</pre>".format(text) if pre else text
        self._viz.text(body,
                       win=title,
                       opts=dict(title=title,
                                 width=min(width, 400),
                                 height=min(height, 400)))

    def plot_bar(self, data, labels, title):
        """Draw ``data`` as an unstacked bar chart with legend ``labels``."""
        bar_opts = dict(legend=labels, stacked=False, title=title)
        self._viz.bar(win=title, X=data, opts=bar_opts)

    def plot_scatter(self, data, labels, title):
        """Scatter-plot several point sets, one class per set.

        ``data`` is a sequence of point arrays; the points of the k-th set
        (counting from 1 via ``enumerate(data, 1)``) all get class id k.
        """
        points = numpy.concatenate(data, axis=0)
        class_ids = numpy.concatenate(
            [numpy.full(len(group), group_id)
             for group_id, group in enumerate(data, 1)],
            axis=0)

        scatter_opts = dict(legend=labels,
                            title=title,
                            markersize=5,
                            webgl=True,
                            width=400,
                            height=400,
                            markeropacity=0.5)
        self._viz.scatter(win=title, X=points, Y=class_ids, opts=scatter_opts)

    def plot_heatmap(self, data, labels, title):
        """Draw ``data`` as a 700x700 heatmap in window ``title``.

        ``labels[0]`` names the rows and ``labels[1]`` the columns. The plotly
        layout puts the x axis on top with ticks rotated by -60 degrees and
        reverses the y axis.
        """
        plotly_layout = {
            'xaxis': {'side': 'top', 'tickangle': -60},
            'yaxis': {'autorange': "reversed"},
        }
        heatmap_opts = dict(title=title,
                            columnnames=labels[1],
                            rownames=labels[0],
                            width=700,
                            height=700,
                            layoutopts={'plotly': plotly_layout})
        self._viz.heatmap(win=title, X=data, opts=heatmap_opts)
コード例 #30
0
                if target_str == actual_str:
                    num_correct += 1
            print("Accuracy: %f" % (num_correct / output.shape[0]))
            last_100_losses = []

        if summarize and rnn.debug:
            loss = np.mean(last_save_losses)
            last_save_losses = []

            print(random_length)
            if (epoch / summarize_freq) % 3 == 0:
                viz.close(None)
            viz.heatmap(
                mhx['memory'][0],
                opts=dict(xtickstep=10,
                          ytickstep=2,
                          title='Memory, t: ' + str(epoch) + ', num: ' +
                          q_and_a_to_string(input_data[0], output[0]),
                          ylabel='layer * time',
                          xlabel='mem_slot * mem_size'))

            all_read_weights = torch.stack(
                [x['read_weights'][0] for x in all_mems])
            viz.heatmap(
                all_read_weights.squeeze(),
                opts=dict(xtickstep=10,
                          ytickstep=2,
                          title='Read Weights, t: ' + str(epoch) + ', num: ' +
                          q_and_a_to_string(input_data[0], output[0]),
                          ylabel='layer * time',
                          xlabel='nr_read_heads * mem_slot'))
コード例 #31
0
    viz.bar(X=np.abs(np.random.rand(5, 3)),
            opts=dict(stacked=True,
                      legend=['Facebook', 'Google', 'Twitter'],
                      rownames=['2012', '2013', '2014', '2015', '2016']))
    viz.bar(X=np.random.rand(20, 3),
            opts=dict(stacked=False,
                      legend=['The Netherlands', 'France', 'United States']))

    # histogram
    viz.histogram(X=np.random.rand(10000), opts=dict(numbins=20))

    # heatmap
    viz.heatmap(
        X=np.outer(np.arange(1, 6), np.arange(1, 11)),
        opts=dict(
            columnnames=['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'],
            rownames=['y1', 'y2', 'y3', 'y4', 'y5'],
            colormap='Electric',
        ))

    # contour
    x = np.tile(np.arange(1, 101), (100, 1))
    y = x.transpose()
    X = np.exp((((x - 50)**2) + ((y - 50)**2)) / -(20.0**2))
    viz.contour(X=X, opts=dict(colormap='Viridis'))

    # surface
    viz.surf(X=X, opts=dict(colormap='Hot'))

    # line plots
    viz.line(Y=np.random.rand(10), opts=dict(showlegend=True))
コード例 #32
0
ファイル: loggers.py プロジェクト: mikekestemont/seqmod
class VisdomLogger(Logger):
    """
    Logger that uses visdom to create learning curves

    Parameters:
    ===========
    - env: str, name of the visdom environment
    - log_checkpoints: bool, whether to use checkpoints or epoch averages
        for training loss
    - losses: tuple, names of the different losses that will be plotted.
    - phases: tuple, names of the phases each loss is plotted for; one curve
        is drawn per (phase, loss) pair, labelled "phase.loss".
    - server: str, address of the visdom server.
    - port: int, port of the visdom server.
    - max_y: optional number; when truthy, plotted values are capped at it.
    - opts: extra visdom plot options ('legend' is always overwritten).
    """
    def __init__(self,
                 env=None,
                 log_checkpoints=True,
                 losses=('loss', ),
                 phases=('train', 'valid'),
                 server='http://localhost',
                 port=8097,
                 max_y=None,
                 **opts):
        self.viz = None
        if Visdom is not None:
            self.viz = Visdom(server=server, port=port, env=env)
        self.legend = ['%s.%s' % (phase, loss) for phase in phases for loss in losses]
        opts.update({'legend': self.legend})
        self.opts = opts
        self.env = env
        self.max_y = max_y
        self.log_checkpoints = log_checkpoints
        self.losses = set(losses)
        # last plotted point per phase and loss; None until a first value arrives
        self.last = {phase: {loss: None for loss in losses} for phase in phases}
        self.pane = self._init_pane()

    @skip_on_import_error(Visdom)
    def _init_pane(self):
        """Create the shared line window with one empty (NaN) trace per legend entry."""
        # ``np.nan`` rather than the ``np.NAN`` alias, which NumPy 2.0 removed
        nan = np.array([np.nan, np.nan])
        X = np.column_stack([nan] * len(self.legend))
        Y = np.column_stack([nan] * len(self.legend))
        return self.viz.line(
            X=X, Y=Y, env=self.env, opts=self.opts)

    def _update_last(self, epoch, loss, phase, loss_label):
        """Remember (epoch, loss) as the latest point of the given curve."""
        self.last[phase][loss_label] = {'X': epoch, 'Y': loss}

    def _plot_line(self, X, Y, phase, loss_label):
        """Append the segment from the last recorded point to (X, Y) to curve "phase.loss_label"."""
        name = "%s.%s" % (phase, loss_label)
        X = np.array([self.last[phase][loss_label]['X'], X])
        Y = np.array([self.last[phase][loss_label]['Y'], Y])
        if self.max_y:
            Y = np.clip(Y, Y.min(), self.max_y)
        # ``line(update='append')`` is visdom's replacement for the deprecated ``updateTrace``
        self.viz.line(
            X=X, Y=Y, name=name, update='append', win=self.pane, env=self.env)

    def _plot_payload(self, epoch, losses, phase):
        """Plot every tracked loss in ``losses`` for ``phase``; the first value of a curve is only recorded."""
        for label, loss in losses.items():
            if label not in self.losses:
                continue
            if self.last[phase][label] is not None:
                self._plot_line(epoch, loss, phase=phase, loss_label=label)
            self._update_last(epoch, loss, phase, label)

    @skip_on_import_error(Visdom)
    def epoch_end(self, payload):
        if self.log_checkpoints:
            # only use epoch end if checkpoint isn't being used
            return
        losses, epoch = payload['loss'], payload['epoch'] + 1
        self._plot_payload(epoch, losses, 'train')

    @skip_on_import_error(Visdom)
    def validation_end(self, payload):
        losses, epoch = payload['loss'], payload['epoch'] + 1
        self._plot_payload(epoch, losses, 'valid')

    @skip_on_import_error(Visdom)
    def checkpoint(self, payload):
        # fractional epoch so checkpoints within an epoch are spread along the x axis
        epoch = payload['epoch'] + payload["batch"] / payload["total_batches"]
        losses = payload['loss']
        self._plot_payload(epoch, losses, 'train')

    @skip_on_import_error(Visdom)
    def attention(self, payload):
        """Show an attention matrix (payload["att"]) as a heatmap with hypothesis/target token labels."""
        title = "epoch {epoch}/ batch {batch_num}".format(**payload)
        if 'title' in self.opts:
            title = self.opts['title'] + ": " + title
        self.viz.heatmap(
            X=np.array(payload["att"]),
            env=self.env,
            opts={'rownames': payload["hyp"],
                  'columnnames': payload["target"],
                  'title': title})
コード例 #33
0
class Logger():
    """Training logger: prints progress with an ETA to stdout and plots to Visdom.

    Per-batch losses are summed over an epoch and plotted as epoch means;
    images are shown as heatmaps refreshed on every call to ``log``.
    """

    def __init__(self, n_epochs, batches_epoch):
        """
        Parameters
        ----------
        n_epochs : int
            Total number of epochs, used for the progress line and the ETA.
        batches_epoch : int
            Batches per epoch; every ``batches_epoch``-th call to ``log`` closes an epoch.
        """
        self.viz = Visdom()
        self.n_epochs = n_epochs
        self.batches_epoch = batches_epoch
        self.epoch = 1
        self.batch = 1
        self.prev_time = time.time()
        # total wall time spent between calls to ``log`` (despite the name, not a mean)
        self.mean_period = 0
        # running per-epoch sums of each loss
        self.losses = {}
        # Visdom window ids, created lazily on first use
        self.loss_windows = {}
        self.image_windows = {}

    def log(self, losses=None, images=None):
        """Record one batch: accumulate losses, print progress, refresh plots.

        Parameters
        ----------
        losses : dict of str -> tensor, optional
            Scalar loss tensors for this batch; each is converted with ``.item()``
            and added to a running per-epoch sum. ``None`` means no losses.
        images : dict of str -> tensor, optional
            Image batches to show as heatmaps. ``None`` means no images.
        """
        losses = {} if losses is None else losses
        images = {} if images is None else images

        now = time.time()
        self.mean_period += (now - self.prev_time)
        self.prev_time = now

        sys.stdout.write(
            '\rEpoch %03d/%03d [%04d/%04d] -- ' %
            (self.epoch, self.n_epochs, self.batch, self.batches_epoch))

        n_losses = len(losses)
        for i, (loss_name, loss) in enumerate(losses.items()):
            # ``.item()`` extracts the Python number from a scalar tensor; the old
            # ``.data[0]`` idiom raises IndexError on 0-dim tensors (PyTorch >= 0.4)
            self.losses[loss_name] = self.losses.get(loss_name, 0.0) + loss.item()

            separator = ' -- ' if (i + 1) == n_losses else ' | '
            sys.stdout.write('%s: %.4f%s' %
                             (loss_name, self.losses[loss_name] / self.batch, separator))

        batches_done = self.batches_epoch * (self.epoch - 1) + self.batch
        batches_left = self.batches_epoch * (
            self.n_epochs - self.epoch) + self.batches_epoch - self.batch
        sys.stdout.write('ETA: %s' % (datetime.timedelta(
            seconds=batches_left * self.mean_period / batches_done)))

        # Draw images
        for image_name, tensor in images.items():
            # NOTE(review): shows batch element 1 (not 0) of what looks like a 4-D
            # (batch, C, H, W) tensor, so batches need at least 2 images — confirm intent.
            heatmap = tensor2image(tensor.data[1, :, :, :])
            image_opts = {'colormap': 'Greys', 'title': image_name}
            if image_name not in self.image_windows:
                self.image_windows[image_name] = self.viz.heatmap(heatmap, opts=image_opts)
            else:
                self.viz.heatmap(heatmap,
                                 win=self.image_windows[image_name],
                                 opts=image_opts)

        # End of epoch
        if (self.batch % self.batches_epoch) == 0:
            # Plot mean losses of the finished epoch
            for loss_name, loss_sum in self.losses.items():
                mean_loss = loss_sum / self.batch
                if loss_name not in self.loss_windows:
                    self.loss_windows[loss_name] = self.viz.line(
                        X=np.array([self.epoch]),
                        Y=np.array([mean_loss]),
                        opts={
                            'xlabel': 'epochs',
                            'ylabel': loss_name,
                            'title': loss_name
                        })
                else:
                    self.viz.line(X=np.array([self.epoch]),
                                  Y=np.array([mean_loss]),
                                  win=self.loss_windows[loss_name],
                                  update='append')
                # Reset losses for next epoch
                self.losses[loss_name] = 0.0

            self.epoch += 1
            self.batch = 1
            sys.stdout.write('\n')
        else:
            self.batch += 1
コード例 #34
0
class Callback(object):
    """A class representing routines called reactively at specific phases during training.

    These can be used to log or visualize the training progress using any of the metric scores developed before.
    The values are stored at the end of each training epoch. The following metric scores are currently available:

        * :class:`~gensim.models.callbacks.CoherenceMetric`
        * :class:`~gensim.models.callbacks.PerplexityMetric`
        * :class:`~gensim.models.callbacks.DiffMetric`
        * :class:`~gensim.models.callbacks.ConvergenceMetric`

    """
    def __init__(self, metrics):
        """

        Parameters
        ----------
        metrics : list of :class:`~gensim.models.callbacks.Metric`
            The list of metrics to be reported by the callback.

        """
        self.metrics = metrics

    def set_model(self, model):
        """Save the model instance and initialize any required variables which would be updated throughout training.

        Parameters
        ----------
        model : :class:`~gensim.models.basemodel.BaseTopicModel`
            The model for which the training will be reported (logged or visualized) by the callback.

        Raises
        ------
        ImportError
            If a metric uses the "visdom" logger but Visdom is not installed.

        """
        self.model = model
        self.previous = None
        # check for any metric which need model state from previous epoch
        if any(
                isinstance(metric, (DiffMetric, ConvergenceMetric))
                for metric in self.metrics):
            self.previous = copy.deepcopy(model)
            # store diff diagonals of previous epochs
            self.diff_mat = Queue()
        if any(metric.logger == "visdom" for metric in self.metrics):
            if not VISDOM_INSTALLED:
                raise ImportError("Please install Visdom for visualization")
            self.viz = Visdom()
            # Plot window of every visdom metric, keyed by the metric's index in ``self.metrics`` (the same
            # window is updated with increasing epochs). A dict rather than a list: only visdom metrics get a
            # window, so list positions would not line up with metric indices when other loggers are mixed in.
            self.windows = {}
        if any(metric.logger == "shell" for metric in self.metrics):
            # set logger for current topic model
            self.log_type = logging.getLogger('gensim.models.ldamodel')

    def on_epoch_end(self, epoch, topics=None):
        """Report the current epoch's metric value.

        Called at the end of each training iteration.

        Parameters
        ----------
        epoch : int
            The epoch that just ended.
        topics : list of list of str, optional
            List of tokenized topics. This is required for the coherence metric.

        Returns
        -------
        dict of (str, object)
            Mapping from metric names to their values. The type of each value depends on the metric type,
            for example :class:`~gensim.models.callbacks.DiffMetric` computes a matrix while
            :class:`~gensim.models.callbacks.ConvergenceMetric` computes a float.

        """
        # stores current epoch's metric values
        current_metrics = {}

        # plot all metrics in current epoch
        for i, metric in enumerate(self.metrics):
            label = str(metric)
            value = metric.get_value(topics=topics,
                                     model=self.model,
                                     other_model=self.previous)

            current_metrics[label] = value

            if metric.logger == "visdom":
                # np.ndim also accepts plain Python scalars, which have no ``.ndim`` attribute
                is_matrix = np.ndim(value) > 0
                if epoch == 0 or i not in self.windows:
                    # first report of this metric: create its window
                    # (the missing-window guard avoids a blocking ``Queue.get`` below when epoch 0 was never seen)
                    if is_matrix:
                        diff_mat = np.array([value])
                        viz_metric = self.viz.heatmap(X=diff_mat.T,
                                                      env=metric.viz_env,
                                                      opts=dict(
                                                          xlabel='Epochs',
                                                          ylabel=label,
                                                          title=label))
                        # store current epoch's diff diagonal
                        self.diff_mat.put(diff_mat)
                    else:
                        viz_metric = self.viz.line(Y=np.array([value]),
                                                   X=np.array([epoch]),
                                                   env=metric.viz_env,
                                                   opts=dict(xlabel='Epochs',
                                                             ylabel=label,
                                                             title=label))
                    # saving initial plot window
                    self.windows[i] = copy.deepcopy(viz_metric)
                elif is_matrix:
                    # concatenate with previous epoch's diff diagonals and redraw in the same window
                    diff_mat = np.concatenate(
                        (self.diff_mat.get(), np.array([value])))
                    self.viz.heatmap(X=diff_mat.T,
                                     env=metric.viz_env,
                                     win=self.windows[i],
                                     opts=dict(xlabel='Epochs',
                                               ylabel=label,
                                               title=label))
                    self.diff_mat.put(diff_mat)
                else:
                    # ``line(update='append')`` replaces visdom's deprecated ``updateTrace``
                    self.viz.line(Y=np.array([value]),
                                  X=np.array([epoch]),
                                  env=metric.viz_env,
                                  win=self.windows[i],
                                  update='append')

            if metric.logger == "shell":
                statement = "".join(
                    ("Epoch ", str(epoch), ": ", label, " estimate: ",
                     str(value)))
                self.log_type.info(statement)

        # Check every metric (not just the last loop variable) for one that needs the previous epoch's model.
        if any(
                isinstance(metric, (DiffMetric, ConvergenceMetric))
                for metric in self.metrics):
            self.previous = copy.deepcopy(self.model)

        return current_metrics