def __init__(self,
             checkpoint: Path,
             metrics: List[str],
             rank: Optional[int] = None,
             period: int = 100,
             tensorboard: bool = True,
             reduction_tag: str = "none") -> None:
    # NOTE on reduction_tag:
    #   1) for ASR tasks we use #tok (token level)
    #   2) for SSE tasks we use #utt (utterance level)
    self.rank = rank
    self.period = period
    self.reduction_tag = reduction_tag
    # create the checkpoint directory if it does not exist
    checkpoint.mkdir(parents=True, exist_ok=True)
    if rank is None:
        logger_loc = (checkpoint / "trainer.log").as_posix()
        self.header = "Trainer"
    else:
        logger_loc = (checkpoint / f"trainer.rank.{rank}.log").as_posix()
        self.header = f"Rank {rank}"
    self.logger = get_logger(logger_loc, file=True)
    # tensorboard logging only on rank-0 (or single-process) trainer
    if tensorboard and rank in [0, None]:
        if not tensorboard_available:
            warnings.warn("tensorboard is not installed, disabling it ...")
            self.board_writer = None
        else:
            self.board_writer = SummaryWriter(checkpoint)
    else:
        self.board_writer = None
    self.metrics = metrics
    self.reset()
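# --- usage sketch (assumptions: this __init__ belongs to a progress-reporter
# --- class, called ProgressReporter here for illustration only; Path, List,
# --- Optional, warnings, get_logger and SummaryWriter are imported by the
# --- enclosing module) ---
#
# from pathlib import Path
# reporter = ProgressReporter(Path("exp/demo"),
#                             metrics=["loss", "accu"],
#                             rank=None,         # single-process training
#                             period=100,        # report every 100 batches
#                             tensorboard=True)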
# Copyright 2020 Jian Wu
# License: Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
Beam search for transformer based AM (Transformer decoder)
"""
import torch as th
import torch.nn as nn
import torch.nn.functional as tf

from aps.asr.beam_search.utils import BeamSearchParam, BeamTracker, BatchBeamTracker
from aps.asr.beam_search.lm import lm_score_impl, LmType
from aps.utils import get_logger

from typing import List, Dict, Optional

logger = get_logger(__name__)


def greedy_search(decoder: nn.Module,
                  enc_out: th.Tensor,
                  sos: int = -1,
                  eos: int = -1,
                  len_norm: bool = True) -> List[Dict]:
    """
    Greedy search (for debugging, should be equivalent to beam search with
    #beam-size == 1)
    """
    if sos < 0 or eos < 0:
        raise RuntimeError(f"Invalid SOS/EOS ID: {sos:d}/{eos:d}")
    # enc_out: T x N x D
    _, N, _ = enc_out.shape
    if N != 1:
        # greedy search only supports a single utterance at a time
        raise RuntimeError(f"Expected batch size 1, but got N = {N:d}")
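# --- usage sketch (assumptions: `encoder` and `decoder` are trained modules
# --- and the sos/eos ids match the model's vocabulary; `encoder` and `feats`
# --- are illustrative names, not defined in this module) ---
#
# enc_out = encoder(feats)                 # T x 1 x D, single utterance
# hypos = greedy_search(decoder, enc_out, sos=1, eos=2)
# best = hypos[0]                          # dict layout follows the decoder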