예제 #1
0
class TensorboardWriter:
    def __init__(self, log_dir: str, *args: Any, **kwargs: Any):
        r"""A wrapper around tensorboard's ``SummaryWriter``.

        When ``log_dir`` is ``None`` or an empty string no ``SummaryWriter``
        is created; every forwarded method call then becomes a no-op that
        returns ``None``, so nothing is written to disk.

        Args:
            log_dir: Save directory location. Will not write to disk if
                log_dir is ``None`` or an empty string.
            *args: Additional positional args for SummaryWriter
            **kwargs: Additional keyword args for SummaryWriter
        """
        self.writer = None
        if log_dir is not None and len(log_dir) > 0:
            self.writer = SummaryWriter(log_dir, *args, **kwargs)

    def __getattr__(self, item):
        """Forward attributes not found on the wrapper to the writer.

        Only invoked when normal attribute lookup fails. With no writer,
        returns a callable that accepts anything and returns ``None``.
        """
        # Never forward ``writer`` itself or dunder names:
        # * if ``writer`` is missing from the instance (e.g. an object being
        #   rebuilt by ``copy``/``pickle`` before its state is restored),
        #   reading ``self.writer`` below would re-enter this method forever;
        # * returning a no-op lambda for protocol probes such as
        #   ``__deepcopy__``/``__setstate__`` makes ``copy.deepcopy`` of a
        #   dummy writer return ``None``.
        if item == "writer" or (item.startswith("__") and item.endswith("__")):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}")
        if self.writer:
            return getattr(self.writer, item)
        return lambda *args, **kwargs: None

    def __enter__(self):
        """Return the wrapper itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the wrapped writer, if any; exceptions are not suppressed."""
        if self.writer:
            self.writer.close()
예제 #2
0
class TensorboardWriter:
    def __init__(self, log_dir: str, *args: Any, **kwargs: Any):
        r"""A wrapper around tensorboard's ``SummaryWriter``.

        When ``log_dir`` is ``None`` or an empty string no ``SummaryWriter``
        is created; every forwarded method call then becomes a no-op that
        returns ``None``. It can also write a tensorboard video directly
        from numpy image frames (see ``add_video_from_np_images``).

        Args:
            log_dir: Save directory location. Will not write to disk if
                log_dir is ``None`` or an empty string.
            *args: Additional positional args for SummaryWriter
            **kwargs: Additional keyword args for SummaryWriter
        """
        self.writer = None
        if log_dir is not None and len(log_dir) > 0:
            self.writer = SummaryWriter(log_dir, *args, **kwargs)

    def __getattr__(self, item):
        """Forward attributes not found on the wrapper to the writer.

        Only invoked when normal attribute lookup fails. With no writer,
        returns a callable that accepts anything and returns ``None``.
        """
        # Never forward ``writer`` itself or dunder names:
        # * if ``writer`` is missing from the instance (e.g. an object being
        #   rebuilt by ``copy``/``pickle`` before its state is restored),
        #   reading ``self.writer`` below would re-enter this method forever;
        # * returning a no-op lambda for protocol probes such as
        #   ``__deepcopy__``/``__setstate__`` makes ``copy.deepcopy`` of a
        #   dummy writer return ``None``.
        if item == "writer" or (item.startswith("__") and item.endswith("__")):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}")
        if self.writer:
            return getattr(self.writer, item)
        return lambda *args, **kwargs: None

    def __enter__(self):
        """Return the wrapper itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the wrapped writer, if any; exceptions are not suppressed."""
        if self.writer:
            self.writer.close()

    def add_video_from_np_images(self,
                                 video_name: str,
                                 step_idx: int,
                                 images: np.ndarray,
                                 fps: int = 10) -> None:
        r"""Write video into tensorboard from image frames.

        Args:
            video_name: name of video string.
            step_idx: int of checkpoint index to be displayed.
            images: sequence of n frames (a list of arrays or one stacked
                array). Each frame is an np.ndarray of shape (H, W, 3);
                all frames must share the same shape.
            fps: frame per second for output video.

        Returns:
            None. Nothing is written when no writer is configured or when
            ``images`` is empty.
        """
        if not self.writer:
            return
        if len(images) == 0:
            # An empty frame list cannot be stacked; there is nothing to write.
            return
        # np.stack copies the frames into one contiguous (n, H, W, 3) array.
        # The copy also lets torch.from_numpy accept frames that are
        # negatively strided views (e.g. ``img[::-1]``), which it rejects.
        frames = torch.from_numpy(np.stack(images))
        # (n, H, W, 3) -> final shape of video tensor: (1, n, 3, H, W)
        video_tensor = frames.permute(0, 3, 1, 2).unsqueeze(0)
        self.writer.add_video(video_name,
                              video_tensor,
                              fps=fps,
                              global_step=step_idx)
예제 #3
0
class TensorboardLogger(object):
    """Lazily created ``SummaryWriter`` paired with a global step counter.

    Until :meth:`create` is called, unknown attributes resolve to
    :meth:`noop`, so logging calls are silently ignored.
    """

    def __init__(self):
        self._logger = None
        self._global_step = 0

    def create(self, path):
        """Create the underlying ``SummaryWriter`` writing to ``path``."""
        self._logger = SummaryWriter(path)

    def noop(self, *args, **kwargs):
        """Accept any arguments and do nothing (returns ``None``)."""
        return

    def step(self):
        """Advance the global step by one."""
        self._global_step += 1

    @property
    def global_step(self):
        """Current global step, used for every scalar logged."""
        return self._global_step

    def log_scaler_dict(self, log_dict, prefix=''):
        """Log a (possibly nested) dictionary of scalar values.

        Nested dicts are flattened into tags by joining keys with ``_``;
        e.g. ``{'acc': {'top1': x}}`` with ``prefix='val'`` is logged under
        the tag ``val_acc_top1``. Every value is logged at the current
        global step. Does nothing until :meth:`create` has been called.
        """
        if self._logger is None:
            return
        if prefix:
            prefix = f'{prefix}_'
        for name, value in log_dict.items():
            if isinstance(value, dict):
                # The step is not passed along: it is always read from
                # ``self._global_step``. Passing it positionally would bind it
                # to ``prefix`` and clash with the ``prefix=`` keyword.
                self.log_scaler_dict(value, prefix=f'{prefix}{name}')
            else:
                self._logger.add_scalar(f'{prefix}{name}', value,
                                        self._global_step)

    def __getattr__(self, name):
        """Forward unknown attributes to the writer, or to :meth:`noop`."""
        # Never forward ``_logger`` itself or dunder names: a missing
        # ``_logger`` (object rebuilt by copy/pickle) would otherwise recurse
        # forever, and returning ``noop`` for probes like ``__deepcopy__``
        # makes ``copy.deepcopy`` return ``None``.
        if name == '_logger' or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}')
        if self._logger is None:
            return self.noop
        return getattr(self._logger, name)
예제 #4
0
class Logger(object):
    def __init__(self,
                 log_dir=None,
                 tensorboard=False,
                 txt=False,
                 logfile=None,
                 **kwargs):
        """Combined tensorboard / text logger.

        Args:
            log_dir: passed to ``SummaryWriter``; when ``txt`` is set and no
                ``logfile`` is given, ``logfile.txt`` is opened here in
                append mode.
            tensorboard: create a ``SummaryWriter``. If that fails, a warning
                is issued and tensorboard logging stays disabled.
            txt: replace ``sys.stdout`` with ``Tee(old_stdout, logfile)`` and
                write a timestamped banner to it.
            logfile: file object for the text log. When ``None`` the logger
                opens one itself and closes it in :meth:`close`.
            **kwargs: extra keyword arguments for ``SummaryWriter``.
        """
        self.tblog = None
        if tensorboard:
            try:
                self.tblog = SummaryWriter(log_dir=log_dir, **kwargs)
            except Exception as err:
                # Tensorboard is optional: keep going without it, but report
                # why (a bare ``except:`` also hid KeyboardInterrupt/SystemExit).
                import warnings
                warnings.warn(f'Tensorboard logging disabled: {err!r}',
                              stacklevel=2)
                self.tblog = None

        self.txtlog = None
        # File object opened by this logger (not by the caller); closed in close().
        self._owned_logfile = None
        if txt:
            if logfile is None:
                logfile = open(os.path.join(log_dir, 'logfile.txt'), 'a+')
                self._owned_logfile = logfile
            self._old_stdout = sys.stdout
            self.txtlog = Tee(self._old_stdout, logfile)
            sys.stdout = self.txtlog

            now = time.strftime("%b-%d-%Y-%H%M%S")
            title = '**** Beginning Log {} ****\n'.format(now)
            title_stars = '*' * (len(title) - 1) + '\n'

            self.txtlog.write(title_stars + title + title_stars, logonly=True)

        self.global_step = None
        self.tag_fmt = None

    def set_step(self, step):
        """Set the default step used by :meth:`add`."""
        self.global_step = step

    def get_step(self):
        """Return the default step used by :meth:`add`."""
        return self.global_step

    def set_tag_format(self, fmt=None):
        """Set a ``str.format`` template applied to tags (``None`` disables)."""
        self.tag_fmt = fmt

    def get_tag_format(self):
        """Return the current tag template (or ``None``)."""
        return self.tag_fmt

    def add_hparams(self, param_dict, metrics=None):
        """Log hyper-parameters; does nothing when tensorboard is disabled.

        ``metrics`` defaults to a fresh empty dict on each call.
        """
        if self.tblog is None:
            return None
        if metrics is None:
            metrics = {}
        self.tblog.add_hparams(param_dict, metrics)

    def add(self,
            data_type,
            tag,
            *args,
            global_step=None,
            **kwargs):
        """Dispatch to ``self.tblog.add_<data_type>(tag, *args, ...)``.

        ``global_step`` defaults to the step set via :meth:`set_step`, and
        ``tag`` is passed through the tag template when one is set.
        Returns ``None``; does nothing when tensorboard is disabled.
        """
        if self.tblog is None:
            return None

        add_fn = getattr(self.tblog, 'add_{}'.format(data_type))

        if global_step is None:
            global_step = self.global_step

        if self.tag_fmt is not None:
            tag = self.tag_fmt.format(tag)

        add_fn(tag, *args, global_step=global_step, **kwargs)

    def flush(self):
        """Flush whichever of the tensorboard / text logs are active."""
        if self.tblog is not None:
            self.tblog.flush()
        if self.txtlog is not None:
            self.txtlog.flush()

    def close(self):
        """Flush and release resources. Safe to call more than once.

        Closes the tensorboard writer and restores ``sys.stdout`` if this
        logger replaced it. Also closes the log file if this logger opened it.
        """
        self.flush()
        if self.tblog is not None:
            self.tblog.close()
            self.tblog = None
        if self.txtlog is not None:
            if sys.stdout is self.txtlog:
                sys.stdout = self._old_stdout
            self.txtlog = None
        if self._owned_logfile is not None:
            self._owned_logfile.close()
            self._owned_logfile = None