def args():
    """Parse the command line and return the resulting namespace.

    The parsed namespace is handed to ``laia.common.logging.config_from_args``
    before being returned. When its ``print_args`` attribute is truthy, the
    full argument namespace is pretty-printed through this module's logger
    at INFO level.
    """
    parsed = _get_parser().parse_args()
    # Imported lazily, after argument parsing (kept in the original order).
    import laia.common.logging as log

    log.config_from_args(parsed)
    if not parsed.print_args:
        return parsed

    import pprint

    logger = log.get_logger(__name__)
    logger.info("\n{}", pprint.pformat(vars(parsed)))
    return parsed
Exemplo n.º 2
0
 def _run_test_backward(self, implementation, dtype, device, reduction,
                        average_frames):
     """Run ``gradcheck`` on ``CTCLoss`` over a small fixed batch.

     The input ``x`` has size T x N x 3 (here 4 x 3 x 3) and ``y`` holds one
     label sequence per batch element. The loss is built with the given
     ``reduction``, ``average_frames`` and ``implementation``, and its
     gradient w.r.t. ``x`` is checked numerically.

     While the check runs, the ``laia.losses.ctc_loss`` logger is raised to
     ERROR. Its own previously configured level (possibly NOTSET, i.e.
     inherited) is restored afterwards, even if ``gradcheck`` raises.
     """
     ctc_logger = log.get_logger("laia.losses.ctc_loss")
     # Save the logger's own level rather than its effective level, so that
     # a logger inheriting its level (NOTSET) keeps inheriting afterwards.
     prev_level = ctc_logger.level
     ctc_logger.setLevel(log.ERROR)
     try:
         # Size: T x N x 3
         x = torch.tensor(
             [
                 [[0, 1, 2], [2, 3, 1], [0, 0, 1]],
                 [[-1, -1, 1], [-3, -2, 2], [1, 0, 0]],
                 [[0, 0, 0], [0, 0, 1], [1, 1, 1]],
                 [[0, 0, 2], [0, 0, -1], [0, 2, 1]],
             ],
             dtype=dtype,
             device=device,
             requires_grad=True,
         )
         y = [[1], [1, 1, 2, 1], [1, 2, 2]]
         ctc = CTCLoss(
             reduction=reduction,
             average_frames=average_frames,
             implementation=implementation,
         )
         gradcheck(lambda x: ctc(x, y), (x, ))
     finally:
         # Restore the level even when gradcheck fails, so later tests are
         # not affected by the temporary ERROR level.
         ctc_logger.setLevel(prev_level)
def check_nan(tensor, msg=None, name=None, raise_exception=False, **kwargs):
    r"""Report whether a tensor holds NaN elements.

    The tensor is only inspected when the logger ``name`` is enabled for
    DEBUG and the tensor's dtype is in ``_TENSOR_REAL``. Otherwise nothing
    is computed and ``False`` is returned.

    Arguments:
      tensor (torch.Tensor): tensor to inspect.
      msg (str): format string for the reported message. It may reference
          ``abs_num`` (number of NaN elements) and ``rel_num`` (fraction of
          NaN elements over ``tensor.numel()``). When None, a built-in
          message is used. (Default: None)
      name (str): name of the logger used to report the event
          (Default: None)
      raise_exception (bool): raise ``ValueError`` carrying the message
          instead of logging it at DEBUG level (Default: False)
      kwargs: extra keyword arguments forwarded to ``msg.format``.

    Return:
      `True` if NaN elements were found, `False` otherwise.
    """
    logger = log.get_logger(name)
    if not (logger.isEnabledFor(log.DEBUG) and tensor.dtype in _TENSOR_REAL):
        return False

    nan_count = torch.isnan(tensor).sum().item()
    if nan_count <= 0:
        return False

    nan_ratio = nan_count / tensor.numel()
    if msg is None:
        text = "{:d} ({:.2%}) NaN values found".format(nan_count, nan_ratio)
    else:
        text = msg.format(abs_num=nan_count, rel_num=nan_ratio, **kwargs)

    if raise_exception:
        raise ValueError(text)
    logger.debug(text)
    return True
Exemplo n.º 4
0
def check_tensor(
    tensor: torch.Tensor,
    msg: Optional[str] = None,
    name: Optional[str] = "laia",
    raise_exception: bool = False,
    **kwargs,
) -> bool:
    """
    Checks whether a tensor contains any non-finite element.
    Real values are finite when they are not NaN, negative infinity, or
    infinity, so NaN elements are counted together with +/- infinity.

    The tensor is only inspected when the logger ``name`` is enabled for
    DEBUG. Otherwise nothing is computed and ``False`` is returned, even if
    ``raise_exception`` is set.

    Arguments:
      tensor (torch.Tensor): tensor to check.
      msg (str): message format string. The message format can use the keys
          ``abs_num`` and ``rel_num`` to print the absolute number and the
          fraction of non-finite (NaN or +/-inf) elements. (Default: None)
      name (str): Name of the logger used to log the event
          (Default: "laia")
      raise_exception (bool): raise a ``ValueError`` instead of logging the
          event (Default: False)
      kwargs: additional named arguments passed to format the message.

    Return:
      `True` if the tensor contains any non-finite value (NaN or +/-inf),
      `False` otherwise (including when the check is skipped).
    """
    logger = log.get_logger(name)
    if not logger.isEnabledFor(DEBUG):
        return False

    # isfinite is False for NaN as well as +/-inf, so both are counted.
    num = torch.isfinite(tensor).logical_not().sum().item()
    if num <= 0:
        return False

    percentage = num / tensor.numel()
    msg = (
        f"{num:d} ({percentage:.2%}) infinite values found"
        if msg is None
        else msg.format(abs_num=num, rel_num=percentage, **kwargs)
    )
    if raise_exception:
        raise ValueError(msg)
    logger.debug(msg)
    return True
Exemplo n.º 5
0
from __future__ import absolute_import

from typing import Tuple, Callable, Any as AnyT

from laia.common.logging import get_logger, DEBUG, INFO, ERROR

# Logger named after this module, obtained via laia.common.logging.get_logger.
_logger = get_logger(__name__)


class Condition(object):
    """Conditions are objects that when called return either `True` or `False`.
    Typically used inside of Hooks to trigger an action

    Arguments:
        obj (Callable): obj from which a value will be retrieved
        key (Any, optional): Get this key from the obj after being called.
            Useful when the obj() returns a tuple/list/dict. (default: None)
    """

    def __init__(self, obj, key=None):
        # type: (Callable, AnyT) -> None
        self._obj = obj
        self._key = key

    def __call__(self):
        raise NotImplementedError

    def _process_value(self):
        value = self._obj()
        if value is None:
            # An exception happened during the computation
Exemplo n.º 6
0
from __future__ import absolute_import

import io
from os import listdir
from os.path import isfile, join, splitext

from torch._six import string_classes

import laia.common.logging as log
from laia.data.text_image_dataset import TextImageDataset

# Tuple of file extensions used as the default ``img_extensions`` of
# TextImageFromTextTableDataset.
IMAGE_EXTENSIONS = ".jpg", ".png", ".jpeg", ".pbm", ".pgm", ".ppm", ".bmp"

# Logger named after this module.
_logger = log.get_logger(__name__)


class TextImageFromTextTableDataset(TextImageDataset):
    def __init__(
        self,
        txt_table,
        img_dirs,
        img_transform=None,
        txt_transform=None,
        img_extensions=IMAGE_EXTENSIONS,
        encoding="utf8",
    ):
        if isinstance(img_dirs, string_classes):
            img_dirs = [img_dirs]
        # First, load the transcripts and find the corresponding image filenames
        # in the given directory. Also save the IDs (basename) of the examples.
        self._ids, self._imgs, self._txts = _get_images_and_texts_from_text_table(
Exemplo n.º 7
0
                               engine_wrapper,
                               gpu=args.gpu))
    mo_saver_best_cer = RollingSaver(ModelCheckpointSaver(
        ckpt_saver('model.ckpt.lowest-valid-cer'), model),
                                     keep=2)
    mo_saver_best_wer = RollingSaver(ModelCheckpointSaver(
        ckpt_saver('model.ckpt.lowest-valid-wer'), model),
                                     keep=2)
    mo_saver = RollingSaver(
        ModelCheckpointSaver(ckpt_saver('model.ckpt'), model))

    # Calls ``saver.save`` with ``epoch`` as the suffix. Decorated with
    # ``action`` so it can be wrapped in ``Action(save, saver=...)`` in the
    # hooks configured below.
    @action
    def save(saver, epoch):
        saver.save(suffix=epoch)

    log.get_logger('laia.hooks.conditions.multiple_of').setLevel(log.WARNING)

    # Set hooks
    trainer.add_hook(
        EPOCH_END,
        HookList(
            # Save on best CER
            Hook(
                Lowest(engine_wrapper.valid_cer(), name='Lowest CER'),
                ActionList(Action(save, saver=tr_saver_best_cer),
                           Action(save, saver=mo_saver_best_cer))),
            # Save on best WER
            Hook(
                Lowest(engine_wrapper.valid_wer(), name='Lowest WER'),
                ActionList(Action(save, saver=tr_saver_best_wer),
                           Action(save, saver=mo_saver_best_wer))),