예제 #1
0
def create_tb_writer(outdir):
    """Build a TensorBoard ``SummaryWriter`` for *outdir* with aggregate charts.

    Two margin charts are registered under the "Aggregate Charts" category:
    the eval mean bounded by eval min/max, and the eval mean bounded by the
    ``extras/meanplusstdev`` / ``extras/meanminusstdev`` tags.

    Args:
        outdir: Directory the event files are written to.

    Returns:
        The configured ``torch.utils.tensorboard.SummaryWriter``.
    """
    # Imported lazily: this import raises an error when tensorboard < 1.14.
    from torch.utils.tensorboard import SummaryWriter

    min_max_chart = ["Margin", ["eval/mean", "eval/min", "eval/max"]]
    std_chart = [
        "Margin",
        ["eval/mean", "extras/meanplusstdev", "extras/meanminusstdev"],
    ]
    writer = SummaryWriter(log_dir=outdir)
    writer.add_custom_scalars(
        {
            "Aggregate Charts": {
                "mean w/ min-max": min_max_chart,
                "mean +/- std": std_chart,
            }
        }
    )
    return writer
예제 #2
0
class Logger:
    """Per-instance ``logging`` logger plus an optional TensorBoard writer.

    Text goes through a uniquely named ``logging.Logger`` (``logger_<n>``).
    The ``add_*``, ``flush`` and ``close`` methods forward to ``self._writer``,
    which only exists when both ``log_dir`` and ``phase`` were given.
    """

    # Incremented per instance so every Logger gets its own logging.Logger
    # (and therefore its own, non-shared set of handlers).
    _count = 0

    def __init__(self, scrn=True, log_dir='', phase=''):
        """Create the logger.

        Args:
            scrn: When true, attach a stream handler that emits INFO and above
                using ``FORMAT_SHORT``.
            log_dir: Directory for the ``.log`` file and the TensorBoard run.
            phase: Name prefix of the log file / TensorBoard run directory.
                File and TensorBoard logging are enabled only when both
                ``log_dir`` and ``phase`` are non-empty.
        """
        super().__init__()
        self._logger = logging.getLogger('logger_{}'.format(Logger._count))
        Logger._count += 1
        self._logger.setLevel(logging.DEBUG)

        if scrn:
            self._scrn_handler = logging.StreamHandler()
            self._scrn_handler.setLevel(logging.INFO)
            self._scrn_handler.setFormatter(
                logging.Formatter(fmt=FORMAT_SHORT))
            self._logger.addHandler(self._scrn_handler)

        if log_dir and phase:
            # Read the clock once so the .log file and the TensorBoard run
            # directory always share the same timestamped stem (two separate
            # localtime() calls could straddle a second boundary).
            stem = '{}-{:-4d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}'.format(
                phase, *localtime()[:6])
            self.log_path = os.path.join(log_dir, stem + '.log')
            # Emitted before the file handler is attached, so this line only
            # reaches the screen handler (if any).
            self.show_nl("log into {}\n\n".format(self.log_path))
            self._file_handler = logging.FileHandler(filename=self.log_path)
            self._file_handler.setLevel(logging.DEBUG)
            self._file_handler.setFormatter(logging.Formatter(fmt=FORMAT_LONG))
            self._logger.addHandler(self._file_handler)

            self._writer = SummaryWriter(log_dir=os.path.join(log_dir, stem))

    def show(self, *args, **kwargs):
        """Log at INFO level."""
        return self._logger.info(*args, **kwargs)

    def show_nl(self, *args, **kwargs):
        """Log an empty INFO line, then the message at INFO level."""
        self._logger.info("")
        return self.show(*args, **kwargs)

    def dump(self, *args, **kwargs):
        """Log at DEBUG level (passes the file handler, not the screen handler)."""
        return self._logger.debug(*args, **kwargs)

    def warning(self, *args, **kwargs):
        """Log at WARNING level."""
        return self._logger.warning(*args, **kwargs)

    def error(self, *args, **kwargs):
        """Log at ERROR level."""
        return self._logger.error(*args, **kwargs)

    # tensorboard: thin forwards to self._writer
    def add_scalar(self, *args, **kwargs):
        return self._writer.add_scalar(*args, **kwargs)

    def add_scalars(self, *args, **kwargs):
        return self._writer.add_scalars(*args, **kwargs)

    def add_histogram(self, *args, **kwargs):
        return self._writer.add_histogram(*args, **kwargs)

    def add_image(self, *args, **kwargs):
        return self._writer.add_image(*args, **kwargs)

    def add_images(self, *args, **kwargs):
        return self._writer.add_images(*args, **kwargs)

    def add_figure(self, *args, **kwargs):
        return self._writer.add_figure(*args, **kwargs)

    def add_video(self, *args, **kwargs):
        return self._writer.add_video(*args, **kwargs)

    def add_audio(self, *args, **kwargs):
        return self._writer.add_audio(*args, **kwargs)

    def add_text(self, *args, **kwargs):
        return self._writer.add_text(*args, **kwargs)

    def add_graph(self, *args, **kwargs):
        return self._writer.add_graph(*args, **kwargs)

    def add_pr_curve(self, *args, **kwargs):
        return self._writer.add_pr_curve(*args, **kwargs)

    def add_custom_scalars(self, *args, **kwargs):
        return self._writer.add_custom_scalars(*args, **kwargs)

    def add_mesh(self, *args, **kwargs):
        return self._writer.add_mesh(*args, **kwargs)

    # def add_hparams(self, *args, **kwargs):
    #     return self._writer.add_hparams(*args, **kwargs)

    def flush(self):
        return self._writer.flush()

    def close(self):
        return self._writer.close()

    def _grad_hook(self, grad, name=None, grads=None):
        """Record *grad* under *name* in *grads*; returns None so grad is unchanged."""
        grads.update({name: grad})

    def watch_grad(self, model, layers):
        """
        Register gradient hooks on selected parameters of *model*.

        Captured gradients are stored in ``self.grads`` keyed by parameter
        name. The hook handles are kept in ``self.grad_hooks`` (also keyed by
        parameter name) so ``watch_grad_close`` can remove them.
        :param model: object exposing ``named_parameters()`` (e.g. a torch module)
        :param layers: non-empty list of indices into
                        ``list(model.named_parameters())``, e.g. layers=[0, -1]
                        watches the first and the last parameter of the model
        :return: None
        """
        assert layers
        if not hasattr(self, 'grads'):
            self.grads = {}
            self.grad_hooks = {}
        named_parameters = list(model.named_parameters())
        for layer in layers:
            name, param = named_parameters[layer]
            # Watching the same parameter twice must not orphan the old hook.
            previous = self.grad_hooks.pop(name, None)
            if previous is not None:
                previous.remove()
            handle = param.register_hook(
                functools.partial(self._grad_hook, name=name,
                                  grads=self.grads))
            # Key by the parameter's name (not the literal string 'name'),
            # so every handle is retained and can later be removed.
            self.grad_hooks[name] = handle

    def watch_grad_close(self):
        """Remove every hook registered by ``watch_grad``."""
        for handle in self.grad_hooks.values():
            handle.remove()  # remove the hook
        self.grad_hooks.clear()

    def add_grads(self, global_step=None, *args, **kwargs):
        """
        Add the captured gradients to tensorboard as histograms, one per
        watched parameter. You must call self.watch_grad before using this
        method! Extra positional/keyword arguments are forwarded to
        ``add_histogram`` after ``global_step``.
        """
        assert hasattr(self, 'grads'), \
            "self.grads is nonexistent! You must call self.watch_grad before!"
        assert self.grads, "self.grads is empty!"
        for name, grad in self.grads.items():
            # Pass tag/values/global_step positionally so any extra *args line
            # up with the trailing parameters instead of colliding with `tag`.
            self.add_histogram(name, grad, global_step, *args, **kwargs)

    @staticmethod
    def make_desc(counter, total, *triples):
        """Build a progress string like ``"[3/10] loss 0.12 (0.15)"``.

        Each triple is (name to display, object with ``val`` and ``avg``
        attributes such as an AverageMeter, format spec used for both).
        """
        desc = "[{}/{}]".format(counter, total)
        for name, obj, fmt in triples:
            desc += (" {} {obj.val:" + fmt + "} ({obj.avg:" + fmt +
                     "})").format(name, obj=obj)
        return desc
예제 #3
0
class SummaryWriter:
    """Wrapper around a TensorBoard summary writer with a default step and an on/off switch.

    * ``set_global_step(v)`` sets the step used by every ``add_*`` call that
      does not pass ``global_step`` explicitly.
    * ``set_active(False)`` turns every ``add_*`` call into a no-op.

    Constructing an instance also publishes its bound methods on this module:
    each ``add_<x>`` becomes the module-level function ``<x>`` and each
    ``set_<x>`` keeps its name. A later instance overwrites the registrations
    of an earlier one.
    """

    def __init__(self, logdir, flush_secs=120):
        """Open the underlying writer on *logdir*, flushing every *flush_secs* seconds."""
        self.writer = TensorboardSummaryWriter(
            log_dir=logdir,
            purge_step=None,
            max_queue=10,
            flush_secs=flush_secs,
            filename_suffix='')

        self.global_step = None
        self.active = True

        # ------------------------------------------------------------------------
        # register add_* and set_* functions in summary module on instantiation
        # ------------------------------------------------------------------------
        this_module = sys.modules[__name__]
        for name in dir(SummaryWriter):
            # add functions (without the 'add' prefix)
            if name.startswith('add_'):
                setattr(this_module, name[4:], getattr(self, name))
            # set functions
            if name.startswith('set_'):
                setattr(this_module, name, getattr(self, name))

    def _resolve_step(self, global_step):
        """Return *global_step*, or the instance default when it is None."""
        return self.global_step if global_step is None else global_step

    def set_global_step(self, value):
        self.global_step = value

    def set_active(self, value):
        self.active = value

    def add_audio(self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None):
        if self.active:
            self.writer.add_audio(
                tag, snd_tensor, global_step=self._resolve_step(global_step),
                sample_rate=sample_rate, walltime=walltime)

    def add_custom_scalars(self, layout):
        if self.active:
            self.writer.add_custom_scalars(layout)

    def add_custom_scalars_marginchart(self, tags, category='default', title='untitled'):
        if self.active:
            self.writer.add_custom_scalars_marginchart(tags, category=category, title=title)

    def add_custom_scalars_multilinechart(self, tags, category='default', title='untitled'):
        if self.active:
            self.writer.add_custom_scalars_multilinechart(tags, category=category, title=title)

    def add_embedding(self, mat, metadata=None, label_img=None, global_step=None,
                      tag='default', metadata_header=None):
        if self.active:
            self.writer.add_embedding(
                mat, metadata=metadata, label_img=label_img,
                global_step=self._resolve_step(global_step),
                tag=tag, metadata_header=metadata_header)

    def add_figure(self, tag, figure, global_step=None, close=True, walltime=None):
        if self.active:
            self.writer.add_figure(
                tag, figure, global_step=self._resolve_step(global_step),
                close=close, walltime=walltime)

    def add_graph(self, model, input_to_model=None, verbose=False):
        if self.active:
            self.writer.add_graph(model, input_to_model=input_to_model, verbose=verbose)

    def add_histogram(self, tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None):
        if self.active:
            self.writer.add_histogram(
                tag, values, global_step=self._resolve_step(global_step), bins=bins,
                walltime=walltime, max_bins=max_bins)

    def add_histogram_raw(self, tag, min, max, num, sum, sum_squares,
                          bucket_limits, bucket_counts, global_step=None,
                          walltime=None):
        if self.active:
            self.writer.add_histogram_raw(
                tag, min=min, max=max, num=num, sum=sum, sum_squares=sum_squares,
                bucket_limits=bucket_limits, bucket_counts=bucket_counts,
                global_step=self._resolve_step(global_step), walltime=walltime)

    def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
        if self.active:
            self.writer.add_image(
                tag, img_tensor, global_step=self._resolve_step(global_step),
                walltime=walltime, dataformats=dataformats)

    def add_image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None,
                             walltime=None, rescale=1, dataformats='CHW'):
        if self.active:
            self.writer.add_image_with_boxes(
                tag, img_tensor, box_tensor,
                global_step=self._resolve_step(global_step), walltime=walltime,
                rescale=rescale, dataformats=dataformats)

    def add_images(self, tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW'):
        if self.active:
            self.writer.add_images(
                tag, img_tensor, global_step=self._resolve_step(global_step),
                walltime=walltime, dataformats=dataformats)

    def add_mesh(self, tag, vertices, colors=None, faces=None, config_dict=None, global_step=None, walltime=None):
        if self.active:
            self.writer.add_mesh(
                tag, vertices, colors=colors, faces=faces, config_dict=config_dict,
                global_step=self._resolve_step(global_step), walltime=walltime)

    def add_onnx_graph(self, graph):
        if self.active:
            self.writer.add_onnx_graph(graph)

    def add_pr_curve(self, tag, labels, predictions, global_step=None,
                     num_thresholds=127, weights=None, walltime=None):
        if self.active:
            self.writer.add_pr_curve(
                tag, labels, predictions, global_step=self._resolve_step(global_step),
                num_thresholds=num_thresholds, weights=weights, walltime=walltime)

    def add_pr_curve_raw(self, tag, true_positive_counts,
                         false_positive_counts,
                         true_negative_counts,
                         false_negative_counts,
                         precision,
                         recall,
                         global_step=None,
                         num_thresholds=127,
                         weights=None,
                         walltime=None):
        if self.active:
            self.writer.add_pr_curve_raw(
                tag, true_positive_counts,
                false_positive_counts,
                true_negative_counts,
                false_negative_counts,
                precision,
                recall,
                global_step=self._resolve_step(global_step),
                num_thresholds=num_thresholds,
                weights=weights,
                walltime=walltime)

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        if self.active:
            self.writer.add_scalar(
                tag, scalar_value, global_step=self._resolve_step(global_step),
                walltime=walltime)

    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
        if self.active:
            self.writer.add_scalars(
                main_tag, tag_scalar_dict, global_step=self._resolve_step(global_step),
                walltime=walltime)

    def add_text(self, tag, text_string, global_step=None, walltime=None):
        if self.active:
            self.writer.add_text(
                tag, text_string, global_step=self._resolve_step(global_step),
                walltime=walltime)

    def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None):
        if self.active:
            self.writer.add_video(
                tag, vid_tensor, global_step=self._resolve_step(global_step),
                fps=fps, walltime=walltime)

    def close(self):
        self.writer.close()

    def __enter__(self):
        # Enter the underlying writer but hand back this wrapper, so code in
        # the `with` body keeps the default-step and `active` behaviour
        # instead of talking to the raw writer directly.
        self.writer.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return self.writer.__exit__(exc_type, exc_val, exc_tb)
예제 #4
0
class TensorBoard(object):
    """Training-run logger writing scalars, histograms, figures and text to TensorBoard.

    ``self.step`` (set via ``update_step``) is the default step for every
    ``log_*`` method; pass ``step`` explicitly to override it. An explicit
    ``step=0`` is honoured rather than replaced by the default.
    """
    # NOTE(review): not read anywhere in this class; presumably the accepted
    # values of `data_format` — confirm.
    supported_data_formats = ['csv', 'json']

    def __init__(self,
                 path='',
                 title='',
                 params=None,
                 res_iterations=False,
                 data_format='csv'):
        """Open a SummaryWriter on *path* and register a Top-1 error multiline chart.

        If *params* is given it is exported to ``<path>/params.json`` via
        ``export_args_namespace``.
        """
        self.path = path
        self.title = title
        self.writer = SummaryWriter(log_dir=path)
        self.step = 0
        self.text_buffer = ''
        self.last_step = 0
        self.res_iterations = res_iterations
        self.writer.add_custom_scalars({
            'Error': {
                'Top_1':
                ['Multiline', ['training_error1', 'validation_error1']]
            }
        })
        self.training = True
        if params is not None:
            # TODO : hparams support for pytorch
            export_args_namespace(params, '{}/params.json'.format(path))

    def _resolve_step(self, step):
        """Return *step*, or ``self.step`` when *step* is None (0 is a valid step)."""
        return self.step if step is None else step

    def close(self):
        self.writer.close()

    def set_resume_step(self, step):
        # NOTE(review): assumes the writer exposes a mutable `kwargs` dict that
        # is consulted when it (re)opens its event file — confirm against the
        # SummaryWriter actually imported.
        self.writer.kwargs['purge_step'] = step

    def set_training(self, training):
        self.training = training

    def update_step(self, step):
        self.step = step

    def log_results(self, step=None, **kwargs):
        """Log each keyword as a scalar; spaces in names become underscores."""
        step = self._resolve_step(step)
        for k, v in kwargs.items():
            k = k.replace(' ', '_')
            self.writer.add_scalar(k, v, step)

    def log_buffers(self, step=None, **kwargs):
        """Log each keyword as a histogram; '.bn.' in names collapses to '.'."""
        step = self._resolve_step(step)
        for k, v in kwargs.items():
            k = k.replace('.bn.', '.')
            self.writer.add_histogram(k, v, step)

    def log_delay(self, delay_dist, step=None):
        """Plot *delay_dist* as a bar chart with its weighted mean marked.

        *delay_dist* is a mapping whose keys are plotted on the x axis and
        whose values are the bar heights (presumably delay -> count; confirm).
        The vertical line is sum(k * v) / sum(v).
        """
        step = self._resolve_step(step)
        mean_delay = sum(k * v for k, v in delay_dist.items()) / sum(
            delay_dist.values())
        fig = plt.figure(figsize=(12, 8))
        # Materialise the dict views; plotting APIs expect sequences.
        plt.bar(list(delay_dist.keys()),
                list(delay_dist.values()),
                width=1.0,
                color='xkcd:blue',
                edgecolor='black')
        plt.axvline(x=mean_delay, color='firebrick')
        self.writer.add_figure('Server/delay_distribution',
                               fig,
                               global_step=step)

    def log_text(self, text, step=None, tag='text'):
        """Log *text* under *tag*.

        The writer's signature is ``add_text(tag, text_string, step)``; the text
        previously went in the tag slot and the integer step in the text slot.
        """
        self.writer.add_text(tag, text, self._resolve_step(step))

    def log_model(self, server, step=None):
        """Log weight/gradient norms, optimizer regime and, if sharded, worker statistics."""
        step = self._resolve_step(step)
        if hasattr(server,
                   '_shards_weights') and server._shards_weights is not None:
            for k, v in server.get_workers_mean_statistics().items():
                self.writer.add_scalar('Server/workers_mean_statistics/' + k,
                                       v, step)
            for k, v in server.get_workers_master_statistics().items():
                self.writer.add_scalar('Server/workers_master_statistics/' + k,
                                       v, step)
            self.writer.add_scalar('Server/mean_master_distance',
                                   server.get_mean_master_dist(), step)

        self.writer.add_scalar('Model/weights_distance_from_init',
                               server.get_server_weights_dist_norm(), step)
        weights_norm, gradients_norm = server.get_server_norms()
        self.writer.add_scalar('Model/gradients_norm', gradients_norm, step)
        self.writer.add_scalar('Model/weights_norm', weights_norm, step)
        for k, v in server.get_optimizer_regime().items():
            self.writer.add_scalar('Regime/' + k, v, step)
예제 #5
0
        print(
            f'    Continuing from checkpoint {chk_data["checkpointN"]+1}, batch {chk_data["lastEpoch"]+1}/{chk_data["lastBatch"]+1}'
        )

    writer = SummaryWriter('log', purge_step=chk_data['lastBatchId'])
    writer.add_custom_scalars({
        'Log-loss': {
            'Batch': ['Multiline', ['Log-loss/train', 'Log-loss/eval']],
            'Epoch average': [
                'Multiline',
                ['Log-loss-epoch-avg/train', 'Log-loss-epoch-avg/eval']
            ]
        },
        'Loss': {
            'Batch': ['Multiline', ['Loss/train', 'Loss/eval']],
            'Epoch average':
            ['Multiline', ['Loss-epoch-avg/train', 'Loss-epoch-avg/eval']]
        },
        'Likelihood': {
            'Batch': ['Multiline', ['Likelihood/train', 'Likelihood/eval']],
            'Epoch average': [
                'Multiline',
                ['Likelihood-epoch-avg/train', 'Likelihood-epoch-avg/eval']
            ]
        }
    })

    to_tensor = ToTensor()

    def to_pilimagelab(pic):
        pic = pic.mul(255).byte()
예제 #6
0
# Scalar tag names that get grouped into custom TensorBoard charts.
# get_layout presumably builds one chart per name with a line per entry of
# section_names — confirm against its definition.
loss_names = ['_total_', 'class', 'mask', 'rpn_bbox', 'rpn_score']
metric_names = [
    f'{target}_{stat}_{threshold}'
    for target in ('bbox', 'mask')
    for stat in ('AP', 'mAP')
    for threshold in ('0.25', '0.5')
] + ['segment_avg_iou']
section_names = ['trainval', 'val']
layout = get_layout(dict(loss=loss_names, metric=metric_names), section_names)

# %%

train_writer.add_custom_scalars(layout)

val_losses, val_metrics, train_losses, train_metrics = train(
    model,
    loss,
    train_loader,
    val_loader,
    trainval_loader,
    optimizer,
    scheduler,
    device,
    batch_to_device,
    experiment,
    checkpoint_storage,
    loss_checkpoint_storage,
    train_writer,
예제 #7
0
파일: train.py 프로젝트: NumberChiffre/LTPA
class AttentionNetwork:
    """Train/evaluate loop for an attention model with TensorBoard logging.

    The model is called as ``pred, a1, a2, a3 = model(x)``. ``pred`` feeds the
    criterion and the argmax accuracy; the three attention outputs are
    overlaid on image grids via ``plot_attention``. Per-batch and per-epoch
    loss/accuracy, the learning rate, sample image grids and attention maps
    are written to a ``SummaryWriter``. Weights are saved after every training
    epoch.
    """

    def __init__(
        self,
        opt: argparse.ArgumentParser,
        model: torch.nn.Module,
        criterion: torch.nn,
        optimizer: optim.Optimizer,
        scheduler: lr_scheduler.LambdaLR,
        device: torch.device,
        loglevel: int = 20,
    ):
        self.opt = opt
        # Snapshot taken while only `opt` is set, i.e. {'opt': opt}; its str()
        # becomes part of the TensorBoard run directory name below.
        self.params = vars(self).copy()
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.writer = SummaryWriter(
            log_dir=f'{ROOT_DIR}/{opt.logs_path}/tensorboard/'
            f'{strftime("%Y-%m-%d", gmtime())}/'
            f'{str(self.params)}_{strftime("%Y-%m-%d %H-%M-%S", gmtime())}')
        self.logger = ProjectLogger(level=loglevel)
        # [0]: first training batch sample, [1]: first test batch sample.
        self.images = []
        # Number of sample images per grid (5 x 5).
        self.image_dim = int(5 * 5)
        self.step = 0
        # Up-sampling factor for attention map 1; maps 2 and 3 use 2x and 4x.
        self.min_up_factor = 2

    def train_validate(self, train_loader: torch.utils.data.DataLoader,
                       test_loader: torch.utils.data.DataLoader):
        """Alternate train/test for ``opt.epochs`` epochs, then add the summary charts."""
        for epoch in range(self.opt.epochs):
            self.train(train_loader=train_loader, epoch=epoch)
            self.test(test_loader=test_loader, epoch=epoch)
        tb_layout = {
            'Training': {
                'Losses':
                ['Multiline', ['epoch_train_loss', 'epoch_test_loss']],
                'Accuracy':
                ['Multiline', ['epoch_train_acc', 'epoch_test_acc']],
            }
        }
        self.writer.add_custom_scalars(tb_layout)
        self.writer.close()

    def train(self, train_loader: torch.utils.data.DataLoader, epoch: int):
        """Run one training epoch, log metrics and checkpoint the weights.

        Returns:
            (mean batch loss, mean batch accuracy) over the epoch.
        """
        self.writer.add_scalar('train_learning_rate',
                               self.optimizer.param_groups[0]['lr'], epoch)
        self.logger.info(f'epoch {epoch} completed')
        self.model.train()
        epoch_loss, epoch_acc = [], []
        for batch_idx, (inputs, targets) in enumerate(train_loader, 0):
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            self.model.zero_grad()
            self.optimizer.zero_grad()
            # Keep one training sample batch (captured once, so the list does
            # not grow every epoch) for the grids/attention maps in test().
            if batch_idx == 0 and len(self.images) == 0:
                self.images.append(inputs[0:self.image_dim, :, :, :])
            pred, _, _, _ = self.model(inputs)
            loss = self.criterion(pred, targets)
            loss.backward()
            self.optimizer.step()
            predict = torch.argmax(pred, 1)
            total = targets.size(0)
            correct = torch.eq(predict, targets).cpu().sum().item()
            acc = correct / total
            if batch_idx % 10 == 0:
                self.logger.info(f"[epoch {epoch}][batch_idx {batch_idx}]"
                                 f"loss: {round(loss.item(), 4)} "
                                 f"accuracy: {100 * acc}% ")
            epoch_loss += [loss.item()]
            epoch_acc += [acc]
            self.step += 1
            self.writer.add_scalar('train_loss', loss.item(), self.step)
            self.writer.add_scalar('train_acc', acc, self.step)

        # Advance the LR schedule only after this epoch's optimizer steps, so
        # the first scheduled LR is actually used (and matches the logged LR).
        self.scheduler.step()

        # log/add on tensorboard
        train_loss = np.mean(epoch_loss, axis=0)
        train_acc = np.mean(epoch_acc, axis=0)
        self.writer.add_scalar('epoch_train_loss', train_loss, epoch)
        self.writer.add_scalar('epoch_train_acc', train_acc, epoch)
        self.logger.info(f"[epoch {epoch}] train_acc: {100 * train_acc}%")
        self.logger.info(f"[epoch {epoch}] train_loss: {train_loss}")
        self.inputs = inputs

        # save model params
        states_dir = f'{ROOT_DIR}/{self.opt.logs_path}/model_states'
        os.makedirs(states_dir, exist_ok=True)
        torch.save(self.model.state_dict(),
                   f'{states_dir}/net_epoch_{epoch}.pth')
        return train_loss, train_acc

    def test(self, test_loader: torch.utils.data.DataLoader, epoch: int):
        """Evaluate on *test_loader*, log metrics, image grids and attention maps.

        Expects ``train`` to have run first (it provides ``self.images[0]``),
        and epoch 0 to be evaluated before later epochs (it builds the grids).

        Returns:
            (mean batch loss, mean batch accuracy) over the test set.
        """
        epoch_loss, epoch_acc = [], []
        self.model.eval()
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(test_loader, 0):
                inputs, targets = inputs.to(self.device), targets.to(
                    self.device)
                # Keep one *test* sample batch, captured once, for the test grid.
                if batch_idx == 0 and len(self.images) == 1:
                    self.images.append(inputs[0:self.image_dim, :, :, :])
                pred, _, _, _ = self.model(inputs)
                loss = self.criterion(pred, targets)
                predict = torch.argmax(pred, 1)
                total = targets.size(0)
                correct = torch.eq(predict, targets).cpu().sum().item()
                acc = correct / total
                epoch_loss += [loss.item()]
                epoch_acc += [acc]

            # log/add on tensorboard
            test_loss = np.mean(epoch_loss, axis=0)
            test_acc = np.mean(epoch_acc, axis=0)
            self.writer.add_scalar('epoch_test_loss', test_loss, epoch)
            self.writer.add_scalar('epoch_test_acc', test_acc, epoch)
            self.logger.info(f"[epoch {epoch}] test_acc: {100 * test_acc}%")
            self.logger.info(f"[epoch {epoch}] test_loss: {test_loss}")

            # initial image grids, built once at epoch 0
            nrow = int(np.sqrt(self.image_dim))
            if epoch == 0:
                self.train_image = utils.make_grid(self.images[0],
                                                   nrow=nrow,
                                                   normalize=True,
                                                   scale_each=True)
                self.test_image = utils.make_grid(self.images[1],
                                                  nrow=nrow,
                                                  normalize=True,
                                                  scale_each=True)
                self.writer.add_image('train_image', self.train_image, epoch)
                self.writer.add_image('test_image', self.test_image, epoch)

            self._log_attention_maps('train', self.images[0],
                                     self.train_image, epoch)
            self._log_attention_maps('test', self.images[1],
                                     self.test_image, epoch)
            return test_loss, test_acc

    def _log_attention_maps(self, prefix, images, grid, epoch):
        """Overlay the model's three attention maps for *images* on *grid* and log them.

        Map i (1-based) is up-sampled by ``min_up_factor * 2 ** (i - 1)`` and
        logged under the tag ``'{prefix}_attention_map_{i}'``.
        """
        nrow = int(np.sqrt(self.image_dim))
        _, ae1, ae2, ae3 = self.model(images)
        for index, attention in enumerate((ae1, ae2, ae3), start=1):
            overlay = plot_attention(grid,
                                     attention,
                                     up_factor=self.min_up_factor * 2 ** (index - 1),
                                     nrow=nrow)
            self.writer.add_image(f'{prefix}_attention_map_{index}', overlay,
                                  epoch)
예제 #8
0
                                       global_iteration, k))
                t_k_1 = t_k

                print('Executed in:', rospy.get_time() - t_k)
                print('-' * 10)

            elif NN.Thrust_on and len(Buffer.PWM) >= Buffer.train_batch_size:

                k = k + 1
                t_k = rospy.get_time()
                Buffer.time.appendleft(t_k)

            else:
                print('Processing not started')
                rospy.sleep(0.02)

        layout = {
            'Acceler': {
                'ax': ['Multiline', ['ax_true', 'ax_estimated']],
                'az': ['Multiline', ['az_true', 'az_estimated']]
            },
            'q_dot_central_diff': {
                'q': ['Multiline', ['q_true', 'q_estimated']]
            }
        }
        writer.add_custom_scalars(layout)
        visualisation(NN, Buffer, validation_array)

    except rospy.ROSInterruptException:
        pass