Esempio n. 1
0
    def __init__(
        self,
        *,
        sample_rate=16000,
        n_window_size=320,
        n_window_stride=160,
        window="hann",
        normalize="per_feature",
        n_fft=None,
        preemph=0.97,
        nfilt=64,
        lowfreq=0,
        highfreq=None,
        log=True,
        log_zero_guard_type="add",
        log_zero_guard_value=2**-24,
        dither=CONSTANT,
        pad_to=16,
        max_duration=16.7,
        frame_splicing=1,
        stft_conv=False,
        pad_value=0,
        mag_power=2.0,
    ):
        """Configure the STFT, the mel filterbank and the log zero-guard.

        Args:
            sample_rate: Audio sample rate in Hz. Used to build the mel
                filterbank and to convert ``max_duration`` into samples.
            n_window_size: STFT window length in samples (positive int).
            n_window_stride: STFT hop length in samples (positive int).
            window: Window name. On the ``torch.stft`` path one of 'hann',
                'hamming', 'blackman', 'bartlett' or 'none'; any other name
                also results in no window. With ``stft_conv`` it is passed
                through to ``STFT``.
            normalize: Stored as ``self.normalize``.
            n_fft: FFT size. A falsy value selects the smallest power of two
                that is >= ``n_window_size``.
            preemph: Stored as ``self.preemph``.
            nfilt: Number of mel filters.
            lowfreq: Lowest mel filter frequency in Hz.
            highfreq: Highest mel filter frequency in Hz. A falsy value (None
                or 0) means ``sample_rate / 2``.
            log: Stored as ``self.log``.
            log_zero_guard_type: Either 'add' or 'clamp'.
            log_zero_guard_value: A number, or 'tiny' / 'eps' to use the
                corresponding ``torch.finfo`` value of the input's dtype.
            dither: Stored as ``self.dither``.
            pad_to: ``self.max_length`` is padded by
                ``pad_to - max_length % pad_to`` when ``pad_to > 0``.
                NOTE(review): an already-aligned length still gains a full
                ``pad_to``.
            max_duration: Maximum duration (``max_duration * sample_rate``
                samples) used to derive ``self.max_length``.
            frame_splicing: Stored as ``self.frame_splicing``.
            stft_conv: If True, use the convolution-based ``STFT`` module
                instead of ``torch.stft``.
            pad_value: Stored as ``self.pad_value``.
            mag_power: Stored as ``self.mag_power``.

        Raises:
            ValueError: For a missing, non-int or non-positive window size or
                stride, an unknown ``log_zero_guard_type``, or an unsupported
                string ``log_zero_guard_value``.
        """
        super(FilterbankFeatures, self).__init__()
        if (n_window_size is None or n_window_stride is None
                or not isinstance(n_window_size, int)
                or not isinstance(n_window_stride, int) or n_window_size <= 0
                or n_window_stride <= 0):
            raise ValueError(
                f"{self} got an invalid value for either n_window_size or "
                f"n_window_stride. Both must be positive ints.")
        nemo.logging.info(f"PADDING: {pad_to}")

        self.win_length = n_window_size
        self.hop_length = n_window_stride
        # Default FFT size: the smallest power of two covering the window.
        self.n_fft = n_fft or 2**math.ceil(math.log2(self.win_length))
        self.stft_conv = stft_conv

        if stft_conv:
            nemo.logging.info("STFT using conv")

            # Create helper class to patch forward func for use with AMP
            class STFTPatch(STFT):
                def __init__(self, *params, **kw_params):
                    super(STFTPatch, self).__init__(*params, **kw_params)

                def forward(self, input_data):
                    return super(STFTPatch, self).transform(input_data)[0]

            self.stft = STFTPatch(self.n_fft, self.hop_length, self.win_length,
                                  window)

        else:
            # Log through nemo like the conv branch does, rather than print().
            nemo.logging.info("STFT using torch")
            torch_windows = {
                'hann': torch.hann_window,
                'hamming': torch.hamming_window,
                'blackman': torch.blackman_window,
                'bartlett': torch.bartlett_window,
                'none': None,
            }
            window_fn = torch_windows.get(window, None)
            window_tensor = window_fn(self.win_length,
                                      periodic=False) if window_fn else None
            self.register_buffer("window", window_tensor)
            self.stft = lambda x: torch.stft(
                x,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                center=True,
                # 'none' (or an unrecognised name) registers a None buffer;
                # pass None through instead of calling .to() on it.
                window=(self.window.to(dtype=torch.float)
                        if self.window is not None else None),
            )

        self.normalize = normalize
        self.log = log
        self.dither = dither
        self.frame_splicing = frame_splicing
        self.nfilt = nfilt
        self.preemph = preemph
        self.pad_to = pad_to
        highfreq = highfreq or sample_rate / 2

        # Keyword arguments: librosa >= 0.10 rejects positional sr / n_fft,
        # and the keyword form is accepted by older releases as well.
        filterbanks = torch.tensor(
            librosa.filters.mel(
                sr=sample_rate,
                n_fft=self.n_fft,
                n_mels=nfilt,
                fmin=lowfreq,
                fmax=highfreq,
            ),
            dtype=torch.float,
        ).unsqueeze(0)
        self.register_buffer("fb", filterbanks)

        # Calculate maximum sequence length
        max_length = self.get_seq_len(
            torch.tensor(max_duration * sample_rate, dtype=torch.float))
        max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
        self.max_length = max_length + max_pad
        self.pad_value = pad_value
        self.mag_power = mag_power

        # We want to avoid taking the log of zero
        # There are two options: either adding or clamping to a small value
        if log_zero_guard_type not in ["add", "clamp"]:
            raise ValueError(
                f"{self} received {log_zero_guard_type} for the "
                f"log_zero_guard_type parameter. It must be either 'add' or "
                f"'clamp'.")
        # log_zero_guard_value is the small value we want to use; we support
        # an actual number, or "tiny", or "eps" (resolved per input dtype).
        self.log_zero_guard_value = lambda _: log_zero_guard_value
        if isinstance(log_zero_guard_value, str):
            if log_zero_guard_value == "tiny":
                self.log_zero_guard_value = lambda x: torch.finfo(x.dtype).tiny
            elif log_zero_guard_value == "eps":
                self.log_zero_guard_value = lambda x: torch.finfo(x.dtype).eps
            else:
                raise ValueError(
                    f"{self} received {log_zero_guard_value} for the "
                    f"log_zero_guard_value parameter. It must be either a "
                    f"number, 'tiny', or 'eps'")
        self.log_zero_guard_type = log_zero_guard_type
Esempio n. 2
0
import numpy as np
import torch
from tqdm import trange
from typeguard import check_argument_types

from espnet.utils.cli_utils import get_commandline_args
from espnet2.fileio.sound_scp import SoundScpWriter
from espnet2.tasks.enh import EnhancementTask
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed
from espnet2.utils import config_argparse
from espnet2.utils.types import str2bool
from espnet2.utils.types import str2triple_str
from espnet2.utils.types import str_or_none

# Machine epsilon of torch's default floating dtype, captured at import time.
EPS: float = torch.finfo(torch.get_default_dtype()).eps


class SeparateSpeech:
    """SeparateSpeech class

    Examples:
        >>> import soundfile
        >>> separate_speech = SeparateSpeech("enh_config.yml", "enh.pth")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> separate_speech(audio)
        [separated_audio1, separated_audio2, ...]

    """
    def __init__(
        self,
Esempio n. 3
0
def _reciprocal(x):
    """Return ``1 / x`` elementwise, capped at the dtype's largest finite value.

    ``1 / +0`` would be ``+inf``; the upper clamp turns it into
    ``torch.finfo(x.dtype).max``. No lower bound is applied.
    """
    upper = torch.finfo(x.dtype).max
    return torch.clamp(torch.reciprocal(x), max=upper)
def small_val(dtype):
    """Return the smallest positive normal value of floating-point ``dtype``."""
    info = torch.finfo(dtype)
    return info.tiny
Esempio n. 5
0
def log_p_to_entropy(log_probs):
    """Return the entropy over the last dimension, given natural-log probabilities.

    ``-inf`` log-probabilities are clamped to the dtype's most negative finite
    value before multiplying, so zero-probability entries contribute ``0``
    instead of ``0 * -inf = nan``.
    """
    floor = torch.finfo(log_probs.dtype).min
    probs = torch.exp(log_probs)
    weighted = probs * log_probs.clamp(min=floor)
    return -weighted.sum(dim=-1)
Esempio n. 6
0
def get_rnnt_logprobs(
    lm: Tensor,
    am: Tensor,
    symbols: Tensor,
    termination_symbol: int,
    boundary: Optional[Tensor] = None,
    modified: bool = False,
) -> Tuple[Tensor, Tensor]:
    """
    Reduces RNN-T problem (the simple case, where joiner network is just
    addition), to a compact, standard form that can then be given
    (with boundaries) to mutual_information_recursion().
    This function is called from rnnt_loss_simple(), but may be useful for
    other purposes.

    Args:
      lm:
        Language model part of un-normalized logprobs of symbols, to be added to
        acoustic model part before normalizing.  Of shape::

           [B][S+1][C]

        where B is the batch size, S is the maximum sequence length of
        the symbol sequence, possibly including the EOS symbol; and
        C is size of the symbol vocabulary, including the termination/next-frame
        symbol.
        Conceptually, lm[b][s] is a vector of length [C] representing the
        "language model" part of the un-normalized logprobs of symbols,
        given all symbols *earlier than* s in the sequence.  The reason
        we still need this for position S is that we may still be emitting
        the termination/next-frame symbol at this point.
      am:
        Acoustic-model part of un-normalized logprobs of symbols, to be added
        to language-model part before normalizing.  Of shape::

           [B][T][C]

        where B is the batch size, T is the maximum sequence length of
        the acoustic sequences (in frames); and C is size of the symbol
        vocabulary, including the termination/next-frame symbol.  It reflects
        the "acoustic" part of the probability of any given symbol appearing
        next on this frame.
      symbols:
        A LongTensor of shape [B][S], containing the symbols at each position
        of the sequence.
      termination_symbol:
        The identity of the termination symbol, must be in {0..C-1}
      boundary:
        an optional LongTensor of shape [B, 4] with elements interpreted as
        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
        [0, 0, S, T]
        if boundary is not supplied.
        Most likely you will want begin_symbol and begin_frame to be zero.
        In this function it is only used when ``modified`` is False, where it
        is passed to fix_for_boundary().
      modified:
        if True, each time a real symbol is consumed a frame will
        also be consumed, so at most 1 symbol can appear per frame.
    Returns:
        (px, py) (the names are quite arbitrary).
           px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
           py: logprobs, of shape [B][S+1][T]

      in the recursion::

          p[b,0,0] = 0.0
          if !modified:
             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                p[b,s,t-1] + py[b,s,t-1])
          if modified:
             p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                                p[b,s,t-1] + py[b,s,t-1])
          .. where p[b][s][t] is the "joint score" of the pair of subsequences
          of length s and t respectively.  px[b][s][t] represents the
          probability of extending the subsequences of length (s,t) by one in
          the s direction, given the particular symbol, and py[b][s][t]
          represents the probability of extending the subsequences of length
          (s,t) by one in the t direction,
          i.e. of emitting the termination/next-frame symbol.

          if !modified, px[:,:,T] equals -infinity, meaning on the
          "one-past-the-last" frame we cannot emit any symbols.
          This is simply a way of incorporating
          the probability of the termination symbol on the last frame.
    """
    assert lm.ndim == 3
    assert am.ndim == 3
    assert lm.shape[0] == am.shape[0]
    assert lm.shape[2] == am.shape[2]

    (B, T, C) = am.shape
    S = lm.shape[1] - 1
    assert symbols.shape == (B, S)

    # subtracting am_max and lm_max is to ensure the probs are in a good range
    # to do exp() without causing underflow or overflow.
    am_max, _ = torch.max(am, dim=2, keepdim=True)  # am_max: [B][T][1]
    lm_max, _ = torch.max(lm, dim=2, keepdim=True)  # lm_max: [B][S+1][1]
    am_probs = (am - am_max).exp()
    lm_probs = (lm - lm_max).exp()
    # normalizers: [B][S+1][T]
    # The matmul sums lm_probs * am_probs over the vocabulary axis C, i.e. the
    # log-sum-exp of (lm + am) per (s, t) up to the max shifts; `tiny` keeps
    # log() away from exactly zero.
    normalizers = (torch.matmul(lm_probs, am_probs.transpose(1, 2)) +
                   torch.finfo(am_probs.dtype).tiny).log()

    # add lm_max and am_max to normalizers, to make it as if we had not
    # subtracted am_max and lm_max above.
    normalizers = normalizers + lm_max + am_max.transpose(1, 2)  # [B][S+1][T]

    # px is the probs of the actual symbols..
    # For every frame t, pick am[b, t, symbols[b, s]].
    px_am = torch.gather(
        am.unsqueeze(1).expand(B, S, T, C),
        dim=3,
        index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
    ).squeeze(-1)  # [B][S][T]

    if not modified:
        # Append a -inf column so no symbol can be emitted on the
        # "one-past-the-last" frame.
        px_am = torch.cat(
            (
                px_am,
                torch.full(
                    (B, S, 1),
                    float("-inf"),
                    device=px_am.device,
                    dtype=px_am.dtype,
                ),
            ),
            dim=2,
        )  # now: [B][S][T+1], index [:,:,T] has -inf..

    px_lm = torch.gather(lm[:, :S], dim=2,
                         index=symbols.unsqueeze(-1))  # [B][S][1]

    # px: [B][S][T+1] if !modified (the last slice, index T, is -inf),
    # otherwise [B][S][T]. px_lm broadcasts over the frame axis.
    px = px_am + px_lm
    # Normalize only the real frames; when !modified the -inf slice at index
    # T is left untouched. Shape is unchanged by this in-place op.
    px[:, :, :T] -= normalizers[:, :S, :]

    # py is the probs of termination symbols, of shape [B][S+1][T]
    py_am = am[:, :, termination_symbol].unsqueeze(1)  # [B][1][T]
    py_lm = lm[:, :, termination_symbol].unsqueeze(2)  # [B][S+1][1]
    py = py_am + py_lm - normalizers

    # fix_for_boundary is defined elsewhere; presumably it relocates the -inf
    # column to each sequence's own end frame when `boundary` is given --
    # TODO confirm.
    if not modified:
        px = fix_for_boundary(px, boundary)

    return (px, py)
def pos_inf(dtype):
    """Return the largest finite value of floating-point ``dtype``.

    Despite the name this is a finite number, not ``float('inf')``.
    """
    info = torch.finfo(dtype)
    return info.max
Esempio n. 8
0
def clamp_probs(probs):
    """Clamp probabilities into ``[eps, 1 - eps]`` for ``probs``' dtype.

    Keeps values strictly inside (0, 1), e.g. so that ``log(p)`` and
    ``log(1 - p)`` stay finite.
    """
    margin = torch.finfo(probs.dtype).eps
    lower, upper = margin, 1 - margin
    return torch.clamp(probs, min=lower, max=upper)
    nn.MultiMarginLoss,
    nn.TripletMarginLoss,
]

# Maps a layer-category name to the collection of layer classes defined for
# that category above (CONV_LAYERS, ..., LOSS_LAYERS).
LAYERS_TYPES = dict(
    conv=CONV_LAYERS,
    linear=LINEAR_LAYERS,
    pool=POOL_LAYERS,
    pad=PAD_LAYERS,
    activation=ACTIVATION_LAYERS,
    normalization=NORM_LAYERS,
    dropout=DROPOUT_LAYERS,
    loss=LOSS_LAYERS,
)

# Largest finite float32 value and float32 machine epsilon, each stored as a
# 0-d float32 tensor.
TORCH_FLOAT_MAX = torch.full((), torch.finfo(torch.float32).max,
                             dtype=torch.float32)
TORCH_FLOAT_EPS = torch.full((), torch.finfo(torch.float32).eps,
                             dtype=torch.float32)

# Memo table consulted by _get_max_value, keyed by (exp, man) tuples.
MAX_VALUES = {}


def _get_max_value(exp: int, man: int):
    global MAX_VALUES

    key = (exp, man)

    if key in MAX_VALUES:
        return MAX_VALUES[key]
Esempio n. 10
0
import functools
import torch
import torch.distributed as dist

from enum import Enum

# Finite range of float16; used to saturate values before casting to half.
TORCH_HALF_MIN, TORCH_HALF_MAX = (torch.finfo(torch.float16).min,
                                  torch.finfo(torch.float16).max)


class DQuantType(Enum):
    """Quantization formats accepted by ``_quantize_tensor``."""

    FP16 = "fp16"

    def __str__(self) -> str:
        # Render as the bare value ("fp16") instead of "DQuantType.FP16".
        return f"{self.value}"


def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
    """Saturate ``tensor`` to the finite float16 range, then cast to float16.

    Values below ``TORCH_HALF_MIN`` or above ``TORCH_HALF_MAX`` (including
    infinities) are clamped to the nearest bound, so the cast does not
    overflow to infinity.
    """
    saturated = tensor.clamp(min=TORCH_HALF_MIN, max=TORCH_HALF_MAX)
    return saturated.to(torch.float16)


def _quantize_tensor(tensor, qtype):
    """Quantize ``tensor`` into the format selected by ``qtype``.

    Raises:
        RuntimeError: if ``tensor`` is not a ``torch.Tensor``, or ``qtype`` is
            not a supported ``DQuantType``.
    """
    if isinstance(tensor, torch.Tensor):
        if qtype == DQuantType.FP16:
            return _fp32_to_fp16_with_clamp(tensor)
        raise RuntimeError(f'Quantization type {qtype} is not supported')
    raise RuntimeError(
        f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
    )
Esempio n. 11
0
import torch
import torch.nn as nn

# float32 machine epsilon; SuperLayer adds it to its update denominator to
# guard against division by zero.
EPSILON: float = torch.finfo(torch.float32).eps


# ================================ supervised Net ======================================
class SuperLayer(nn.Module):
    """
    One multiplicative-update step with a Frobenius-norm objective.

    L1 and L2 regularization enter through the ``L1`` / ``L2`` coefficients.
    ``forward(y, x)`` returns::

        y * fc2(x) / (fc1(y) + (L2 * y + L1 + EPSILON))

    Args:
        comp: Number of components (last dimension of ``y``).
        features: Number of input features (last dimension of ``x``).
        L1: L1 coefficient, added to the denominator.
        L2: L2 coefficient, scales ``y`` inside the denominator.
    """
    def __init__(self, comp, features, L1, L2):
        super().__init__()
        self.l_1, self.l_2 = L1, L2
        # Bias-free linear maps (no affine offset): fc1 acts on y, fc2 on x.
        self.fc1 = nn.Linear(comp, comp, bias=False)
        self.fc2 = nn.Linear(features, comp, bias=False)

    def forward(self, y, x):
        projected = self.fc1(y)
        # Regularizers and EPSILON are summed first, then added to fc1(y);
        # EPSILON guards against a zero denominator.
        denominator = projected + (self.l_2 * y + self.l_1 + EPSILON)
        numerator = self.fc2(x)
        ratio = numerator / denominator
        return ratio * y


class SuperNet(nn.Module):
    """
    Class for a Regularized DNMF with varying layers number.
    Input:
Esempio n. 12
0
def max_neg_value(t):
    """Return the most negative finite value representable in ``t``'s dtype."""
    info = torch.finfo(t.dtype)
    return -info.max
Esempio n. 13
0
def joint_mutual_information_recursion(
    px: Sequence[Tensor],
    py: Sequence[Tensor],
    boundary: Optional[Tensor] = None,
) -> Sequence[Tensor]:
    """A recursion that is useful for modifications of RNN-T and similar loss
    functions, where the recursion probabilities have a number of terms and you
    want them reported separately.  See mutual_information_recursion() for more
    documentation of the basic aspects of this.

    Args:
      px:
        a sequence of Tensors, each of the same shape [B][S][T+1]
        (or [B][S][T] for the "modified" recursion).
      py:
        a sequence of Tensors, each of the same shape [B][S+1][T];
        the sequence must be the same length as px.
      boundary:
        optionally, a LongTensor of shape [B][4] containing rows
        [s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end <= S
        and 0 <= t_begin <= t_end <= T, defaulting to [0, 0, S, T].
        These are the beginning and one-past-the-last positions in the x
        and y sequences respectively, and can be used if not all
        sequences are of the same length.
    Returns:
      a Tensor of shape (len(px), B),
      whose sum over dim 0 is the total log-prob of the recursion mentioned
      below, per sequence. The first element of the sequence of length len(px)
      is "special", in that it has an offset term reflecting the difference
      between sum-of-log and log-of-sum; for more interpretable loss values,
      the "main" part of your loss function should be first.

      The recursion below applies if boundary == None, when it defaults
      to (0, 0, S, T); where px_sum, py_sum are the sums of the elements of px
      and py::

          p = tensor of shape (B, S+1, T+1), containing -infinity
          p[b,0,0] = 0.0
          # do the following in loop over s and t:
          p[b,s,t] = log_add(p[b,s-1,t] + px_sum[b,s-1,t],
                              p[b,s,t-1] + py_sum[b,s,t-1])
                      (if s > 0 or t > 0)
          return p[:][S][T]

    This function lets you implement the above recursion efficiently, except
    that it gives you a breakdown of the contribution from all the elements of
    px and py separately.  As noted above, the first element of the
    sequence is "special".
    """
    N = len(px)
    assert len(py) == N and N > 0
    B, S, T1 = px[0].shape
    T = py[0].shape[2]
    assert T1 in [T, T + 1]  # T if modified...
    assert py[0].shape == (B, S + 1, T)
    assert px[0].dtype == py[0].dtype

    px_cat = torch.stack(
        px, dim=0)  # (N, B, S, T+1) if !modified,(N, B, S, T) if modified.
    py_cat = torch.stack(py, dim=0)  # (N, B, S+1, T)
    px_tot = px_cat.sum(dim=0)  # (B, S, T+1)
    py_tot = py_cat.sum(dim=0)  # (B, S+1, T)

    if boundary is not None:
        assert boundary.dtype == torch.int64
        assert boundary.shape == (B, 4)
        for s_begin, t_begin, s_end, t_end in boundary.tolist():
            assert 0 <= s_begin <= s_end <= S
            assert 0 <= t_begin <= t_end <= T

    px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
    # The following assertions are for efficiency
    assert px_tot.ndim == 3
    assert py_tot.ndim == 3

    # Allocated uninitialized; presumably filled in by the forward kernel
    # below and reused by the backward kernel -- TODO confirm against _k2.
    p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype)

    # note, tot_probs is without grad.
    tot_probs = _k2.mutual_information_forward(px_tot, py_tot, boundary, p)

    # this is a kind of "fake gradient" that we use, in effect to compute
    # occupation probabilities.  The backprop will work regardless of the
    # actual derivative w.r.t. the total probs.
    ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)

    (px_grad,
     py_grad) = _k2.mutual_information_backward(px_tot, py_tot, boundary, p,
                                                ans_grad)

    # Flatten the (S, T) axes so each term is a (B, S*T)-style vector; the
    # leading 1 on the grads broadcasts against the N terms.
    px_grad = px_grad.reshape(1, B, -1)
    py_grad = py_grad.reshape(1, B, -1)
    px_cat = px_cat.reshape(N, B, -1)
    py_cat = py_cat.reshape(N, B, -1)
    # get rid of -inf, would generate nan on product with 0
    px_cat = px_cat.clamp(min=torch.finfo(px_cat.dtype).min)
    py_cat = py_cat.clamp(min=torch.finfo(py_cat.dtype).min)

    # _inner_product is defined elsewhere; presumably it sums the elementwise
    # product over the last axis, giving (N, B) per the comments below.
    x_prods = _inner_product(px_grad, px_cat)  # (N, B)
    y_prods = _inner_product(py_grad, py_cat)  # (N, B)

    # If all the occupation counts were exactly 1.0 (i.e. no partial counts),
    # "prods" should be equal to "tot_probs"; however, in general, "tot_probs"
    # will be more positive due to the difference between log-of-sum and
    # sum-of-log
    prods = x_prods + y_prods  # (N, B)
    with torch.no_grad():
        offset = tot_probs - prods.sum(dim=0)  # (B,)
    # Fold the (gradient-free) offset into the first term so the N terms sum
    # exactly to tot_probs.
    prods[0] += offset
    return prods  # (N, B)
Esempio n. 14
0
    def forward(self, p_yolo_tuple, targets, imgs_ts=None):
        ''' Compute the YOLO loss for one batch. Only supports an identical
        number of anchors (anc).

        :param p_yolo_tuple: pconf pcls ptxywh
            pconf: torch.Size([3, 10647, 1])
            pcls: torch.Size([3, 10647, 3])
            ptxywh: torch.Size([3, 10647, 4])
        :param targets: list
            target['boxes'] = target['boxes'].to(device)
            target['labels'] = target['labels'].to(device)
            target['size'] = target['size']
            target['image_id'] = int
        :param imgs_ts: indexed as imgs_ts[i] and forwarded to the matcher as
            img_ts. NOTE(review): the default is None, yet imgs_ts[i] is
            evaluated unconditionally in the matching loop, so calling with
            the default raises TypeError.
        :return: (l_total, log_dict): the total loss tensor and a dict of its
            components as Python floats (keys depend on cfg.MODE_TRAIN).
        '''
        cfg = self.cfg
        pconf, pcls, ptxywh = p_yolo_tuple
        device = ptxywh.device
        batch, hwa, c = ptxywh.shape  # [3, 10647, 4]

        if cfg.MODE_TRAIN == 5:
            # yolo5: conf1 + label1 + goff_xywh4 + gxywh4 + ancwh2 = 12
            gdim = 12
        else:
            # conf-1, cls-num_class, txywh-4, weight-1, gltrb-4
            gdim = 1 + cfg.NUM_CLASSES + 4 + 1 + 4

        # h*w*anc

        # Uninitialized; row i is filled by the matcher for targets[i] below.
        # NOTE(review): rows beyond len(targets) would stay uninitialized.
        gyolos = torch.empty((batch, hwa, gdim), device=device)

        # Match ground truth
        for i, target in enumerate(targets):  # iterate over the batch
            gboxes_ltrb_b = target['boxes']  # ltrb
            glabels_b = target['labels']
            # The string below reads: "visualization happens inside; each
            # feature-map level is different".
            ''' 可视化在里面 每层特图不一样'''

            if cfg.MODE_TRAIN == 5:
                gyolos[i] = fmatch4yolov5(gboxes_ltrb_b=gboxes_ltrb_b,
                                          glabels_b=glabels_b,
                                          dim=gdim,
                                          ptxywh_b=ptxywh[i],
                                          device=device,
                                          cfg=cfg,
                                          img_ts=imgs_ts[i],
                                          pconf_b=pconf[i])
            else:
                # Previous matcher, kept for reference:
                # gyolos[i] = fmatch4yolov3(gboxes_ltrb_b=gboxes_ltrb_b,
                #                           glabels_b=glabels_b,
                #                           dim=gdim,
                #                           ptxywh_b=ptxywh[i],
                #                           device=device, cfg=cfg,
                #                           img_ts=imgs_ts[i],
                #                           pconf_b=pconf[i])

                gyolos[i] = fmatch4yolov3_iou(gboxes_ltrb_b=gboxes_ltrb_b,
                                              glabels_b=glabels_b,
                                              dim=gdim,
                                              ptxywh_b=ptxywh[i],
                                              device=device,
                                              cfg=cfg,
                                              img_ts=imgs_ts[i],
                                              pconf_b=pconf[i],
                                              val_iou=0.3)

        # gyolos [3, 10647, 13] conf-1, cls-3, tbox-4, weight-1, gltrb-4  = 13
        s_ = 1 + cfg.NUM_CLASSES

        # Positives are marked with 1. This is a view into gyolos, so writes
        # to gconf also update gyolos[:, :, 0].
        gconf = gyolos[:, :, 0]

        mask_pos_2d = gconf > 0  # same-shape bool mask; ignored (-1) entries are not counted
        mask_neg_2d = gconf == 0  # ignored (-1) entries are excluded
        # Positives per image, floored at float16 eps so the per-image
        # divisions below never divide by zero.
        nums_pos = (mask_pos_2d.sum(-1).to(
            torch.float)).clamp(min=torch.finfo(torch.float16).eps)

        # # Use IoU as conf: decode pxywh, compute the IoU between predictions and GT, and use it as gconf
        # with torch.no_grad():  # torch.Size([3, 40, 13, 13])
        #     gyolos = gyolos.view(batch, -1, gdim)  # 4d -> 3d [3, 13, 13, 5, 13] -> [3, 169*5, 13]
        #     # mask_pos_2d = gyolos[:, :, 0] == 1  # already matched above; reduce dims [3, xx, 13] -> [3, xx]
        #
        #     gltrb = gyolos[:, :, -4:]  # [3, 169*5, 13] ->  [3, 169*5, 4]
        #     pltrb = boxes_decode4yolo3(ptxywh, cfg)
        #
        #     _pltrb = pltrb.view(-1, 4)
        #     _gltrb = gltrb.view(-1, 4)  # iou only supports 2-D inputs
        #     iou_p = bbox_iou4one(_pltrb, _gltrb, is_ciou=True)  # one-to-one IoU
        #     iou_p = iou_p.view(batch, -1)  # matched IoU per batch [nn,1] -> [batch,nn/batch]
        #
        #     '''visualize the matched predicted boxes'''
        #     debug = False  # torch.isnan(loss_conf_pos)
        #     if debug:  # debug
        #         # d0, d1 = torch.where(mask_pos_2d)  # [3,845]
        #         for i in range(batch):
        #             from f_tools.pic.enhance.f_data_pretreatment4pil import f_recover_normalization4ts
        #             # img_ts = f_recover_normalization4ts(imgs_ts[i])
        #             # mask_ = d0 == i
        #             # _pltrb = _pltrb[d0[mask_], d1[mask_]].cpu()
        #             # _gbox_p = _gltrb[d0[mask_], d1[mask_]].cpu()
        #             _img_ts = imgs_ts[i].clone()
        #             _pltrb_show = pltrb[i][mask_pos[i]]
        #             _gltrb_show = gltrb[i][mask_pos[i]]
        #
        #             iou = bbox_iou4one(_pltrb_show, _gltrb_show)
        #             flog.debug('predicted iou %s', iou)
        #             _img_ts = f_recover_normalization4ts(_img_ts)
        #             f_show_od_ts4plt(_img_ts, gboxes_ltrb=_gltrb_show.detach().cpu()
        #                              , pboxes_ltrb=_pltrb_show.detach().cpu(), is_recover_size=True,
        #                              )
        #
        # gconf = iou_p  # assign the IoU as the target

        gyolos_pos = gyolos[mask_pos_2d]
        log_dict = {}

        if cfg.MODE_TRAIN == 5:
            # yolo5 conf1 + label1 + goff_xywh4 + gxywh4 +ancwh2 =12
            # cls loss: computed on positives only
            ''' ----------------cls损失 只计算正例---------------- '''
            pcls_sigmoid_pos = pcls[mask_pos_2d].sigmoid()  # squash to (0, 1)
            gcls_pos = gyolos_pos[:, 1:2]
            # Labels are stored 1-based; shift to 0-based before one-hot.
            gcls_pos = labels2onehot4ts(gcls_pos - 1, cfg.NUM_CLASSES)
            _loss_val = x_bce(pcls_sigmoid_pos, gcls_pos, reduction="none")
            # NOTE(review): with no positives in the batch this is the mean of
            # an empty tensor, i.e. NaN (same for l_reg below).
            l_cls = _loss_val.sum(-1).mean()
            # box loss: IoU
            ''' ----------------box损失   iou----------------- '''
            ptxty_sigmoid_pos = ptxywh[mask_pos_2d][:, :2].sigmoid(
            ) * 2. - 0.5  # cell offset in [-0.5, 1.5]
            gxywh_pos = gyolos_pos[:, 6:10]
            _ancwh_pos = gyolos_pos[:, 10:12]

            ptwth_sigmoid_pos = (ptxywh[mask_pos_2d][:, 2:4].sigmoid() *
                                 2)**2 * _ancwh_pos  # [0~4]
            pxywh_pos = torch.cat([ptxty_sigmoid_pos, ptwth_sigmoid_pos], -1)

            iou_zg = bbox_iou4one_2d(xywh2ltrb4ts(pxywh_pos),
                                     xywh2ltrb4ts(gxywh_pos),
                                     is_giou=True)
            # iou_zg = bbox_iou4y(xywh2ltrb4ts(pzxywh), gltrb_pos_tx, GIoU=True)
            # print(iou_zg)
            l_reg = (1 - iou_zg).mean()
            # conf loss
            ''' ----------------conf损失 ---------------- '''
            pconf_sigmoid = pconf.sigmoid().view(
                batch, -1)  # [3, 10647, 1] -> [3, 10647]
            gconf[mask_pos_2d] = iou_zg.detach().clamp(0)  # use the (detached, non-negative) IoU as the positive target

            # ------------conf-mse ------------'''
            # _loss_val = F.mse_loss(pconf_sigmoid, gconf, reduction="none")
            # # _loss_val = F.binary_cross_entropy_with_logits(pconf_sigmoid, gconf, reduction="none")
            # l_conf_pos = ((_loss_val * mask_pos_2d).sum(-1) / nums_pos).mean() * 12
            # l_conf_neg = ((_loss_val * mask_neg_2d).sum(-1) / nums_pos).mean() * 2.5

            # pos_ = _loss_val[mask_pos_2d]
            # l_conf_pos = pos_.mean()
            # l_conf_neg = _loss_val[mask_neg_2d].mean()

            # ------------conf-focalloss ------------'''
            # Entries that are neither positive nor negative (marked -1).
            mask_ignore_2d = torch.logical_not(
                torch.logical_or(mask_pos_2d, mask_neg_2d))
            # l_pos, l_neg = focalloss(pconf_sigmoid, gconf, mask_pos=mask_pos_2d, mash_ignore=mask_ignore_2d,
            #                          is_debug=True)
            # l_conf_pos = (l_pos.sum(-1) / nums_pos).mean()
            # l_conf_neg = (l_neg.sum(-1) / nums_pos).mean()

            # ------------conf-ohem ------------'''
            # f_ohem is defined elsewhere; presumably it selects up to
            # nums_pos * 3 hard negatives per image -- TODO confirm.
            _loss_val = x_bce(pconf_sigmoid, gconf, reduction="none")
            mask_neg_hard = f_ohem(_loss_val,
                                   nums_pos * 3,
                                   mask_pos=mask_pos_2d,
                                   mash_ignore=mask_ignore_2d)
            l_conf_pos = ((_loss_val * mask_pos_2d).sum(-1) / nums_pos).mean()
            l_conf_neg = ((_loss_val * mask_neg_hard).sum(-1) /
                          nums_pos).mean()
            # loss complete
            ''' ---------------- loss完成 ----------------- '''
            l_total = l_conf_pos + l_conf_neg + l_cls + l_reg

            log_dict['l_total'] = l_total.item()
            log_dict['l_conf_pos'] = l_conf_pos.item()
            log_dict['l_conf_neg'] = l_conf_neg.item()
            log_dict['l_cls'] = l_cls.item()
            log_dict['l_reg'] = l_reg.item()

        else:
            # cls loss
            ''' ----------------cls损失---------------- '''
            pcls_sigmoid = pcls.sigmoid()  # squash to (0, 1)
            gcls = gyolos[:, :, 1:s_]
            _loss_val = x_bce(pcls_sigmoid, gcls, reduction="none")
            # Only positives contribute; normalized by positives per image.
            l_cls = ((_loss_val.sum(-1) * mask_pos_2d).sum(-1) /
                     nums_pos).mean()
            # conf loss
            ''' ----------------conf损失 ---------------- '''
            pconf_sigmoid = pconf.sigmoid().view(
                batch, -1)  # [3, 10647, 1] -> [3, 10647]

            # ------------conf-mse ------------'''
            # _loss_val = F.mse_loss(pconf_sigmoid, gconf, reduction="none")
            # _loss_val = F.binary_cross_entropy_with_logits(pconf_sigmoid, gconf, reduction="none")

            # l_conf_pos = ((_loss_val * mask_pos_2d).mean(-1)).mean() * 5.
            # l_conf_neg = ((_loss_val * mask_neg_2d).mean(-1)).mean() * 1.
            # l_conf_pos = ((_loss_val * mask_pos_2d).sum(-1) / nums_pos).mean() * 5.
            # l_conf_neg = ((_loss_val * mask_neg_2d).sum(-1) / nums_pos).mean() * 1.

            # l_conf_pos = _loss_val[mask_pos].mean() * 5
            # l_conf_neg = _loss_val[mask_neg].mean()

            # ------------conf-focalloss ------------'''
            # Entries that are neither positive nor negative (marked -1).
            mask_ignore_2d = torch.logical_not(
                torch.logical_or(mask_pos_2d, mask_neg_2d))
            # l_pos, l_neg = focalloss(pconf_sigmoid, gconf, mask_pos=mask_pos_2d, mash_ignore=mask_ignore_2d,
            #                          is_debug=True)
            # l_conf_pos = (l_pos.sum(-1) / nums_pos).mean()
            # l_conf_neg = (l_neg.sum(-1) / nums_pos).mean()

            # ------------conf-ohem ------------'''
            _loss_val = x_bce(pconf_sigmoid, gconf, reduction="none")
            mask_neg_hard = f_ohem(_loss_val,
                                   nums_pos * 3,
                                   mask_pos=mask_pos_2d,
                                   mash_ignore=mask_ignore_2d)
            l_conf_pos = ((_loss_val * mask_pos_2d).sum(-1) / nums_pos).mean()
            l_conf_neg = ((_loss_val * mask_neg_hard).sum(-1) /
                          nums_pos).mean()
            # box loss: BCE for xy, MSE for wh
            ''' ----------------box损失   xy采用bce wh采用mes----------------- '''
            # conf-1, cls-3, tbox-4, weight-1, gltrb-4  = 13
            weight = gyolos[:, :, s_ + 4]  # torch.Size([32, 845])
            ptxty_sigmoid = ptxywh[:, :, :2].sigmoid()  # xy offsets need squashing to (0, 1)
            ptwth = ptxywh[:, :, 2:4]
            gtxty = gyolos[:, :, s_:s_ + 2]
            gtwth = gyolos[:, :, s_ + 2:s_ + 4]

            _loss_val = x_bce(ptxty_sigmoid, gtxty, reduction="none")
            l_txty = ((_loss_val.sum(-1) * mask_pos_2d * weight).sum(-1) /
                      nums_pos).mean()
            _loss_val = F.mse_loss(ptwth, gtwth, reduction="none")
            l_twth = ((_loss_val.sum(-1) * mask_pos_2d * weight).sum(-1) /
                      nums_pos).mean()

            l_total = l_conf_pos + l_conf_neg + l_cls + l_txty + l_twth

            log_dict['l_total'] = l_total.item()
            log_dict['l_conf_pos'] = l_conf_pos.item()
            log_dict['l_conf_neg'] = l_conf_neg.item()
            log_dict['l_cls'] = l_cls.item()
            log_dict['l_xy'] = l_txty.item()
            log_dict['l_wh'] = l_twth.item()

        # log_dict['p_max'] = pconf.max().item()
        # log_dict['p_min'] = pconf.min().item()
        # log_dict['p_mean'] = pconf.mean().item()
        return l_total, log_dict
 def safe_zero_division(numerator: torch.Tensor,
                        denominator: torch.Tensor) -> torch.Tensor:
     """Divide ``numerator`` by ``denominator`` without dividing by zero.

     Denominators whose magnitude is below ``torch.finfo(numerator.dtype).tiny``
     are replaced by that value, keeping their sign. Zeros, including ``-0.0``,
     become ``+tiny``. All other denominators, negative ones included, are used
     unchanged, so results for non-negative denominators match a plain
     ``clamp(min=tiny)``.

     Args:
         numerator: Dividend tensor; its dtype selects the ``tiny`` guard.
         denominator: Divisor tensor, broadcastable against ``numerator``.

     Returns:
         ``numerator`` divided by the guarded denominator.
     """
     eps: float = torch.finfo(numerator.dtype).tiny  # type: ignore
     # Clamp the magnitude, then restore the sign. Clamping the raw tensor to
     # ``min=eps`` also mapped every negative denominator to ``+eps``, which
     # produced huge quotients with the wrong sign.
     magnitude = torch.clamp(denominator.abs(), min=eps)
     return numerator / torch.where(denominator < 0, -magnitude, magnitude)
Esempio n. 16
0
def _phase_congruency(x: torch.Tensor,
                      scales: int = 4,
                      orientations: int = 4,
                      min_length: int = 6,
                      mult: int = 2,
                      sigma_f: float = 0.55,
                      delta_theta: float = 1.2,
                      k: float = 2.0) -> torch.Tensor:
    r"""Compute Phase Congruence for a batch of greyscale images

    Args:
        x: Tensor. Shape :math:`(N, 1, H, W)`.
        scales: Number of wavelet scales
        orientations: Number of filter orientations
        min_length: Wavelength of smallest scale filter
        mult: Scaling factor between successive filters
        sigma_f: Ratio of the standard deviation of the Gaussian
            describing the log Gabor filter's transfer function
            in the frequency domain to the filter center frequency.
        delta_theta: Ratio of angular interval between filter orientations
            and the standard deviation of the angular Gaussian function
            used to construct filters in the freq. plane.
        k: No of standard deviations of the noise energy beyond the mean
            at which we set the noise threshold point, below which phase
            congruency values get penalized.

    Returns:
        Phase Congruency map with shape :math:`(N, 1, H, W)`

    """
    EPS = torch.finfo(x.dtype).eps

    N, _, H, W = x.shape

    # Frequency-domain filter bank; it is viewed below as
    # (1, orientations, scales, H, W).
    filters = _construct_filters(x, scales, orientations, min_length, mult,
                                 sigma_f, delta_theta, k)

    # Fourier transform of the image, spatial-domain filters, and the filtered
    # responses. Both branches lay ``even_odd`` out as
    # (N, orientations, scales, H, W, 2) with [..., 0] real and [..., 1] imag.
    recommended_torch_version = '1.8.0'
    if _version_tuple(
            torch.__version__) >= _version_tuple(recommended_torch_version):
        imagefft = torch.fft.fft2(x)
        filters_ifft = torch.fft.ifft2(filters)
        filters_ifft = filters_ifft.real * math.sqrt(H * W)
        even_odd = torch.view_as_real(torch.fft.ifft2(
            imagefft * filters)).view(N, orientations, scales, H, W, 2)
    else:
        # Legacy (< 1.8) complex-as-last-dim API.
        imagefft = torch.rfft(x, 2, onesided=False)
        filters_ifft = torch.ifft(
            torch.stack([filters, torch.zeros_like(filters)], dim=-1), 2)[...,
                                                                          0]
        filters_ifft *= math.sqrt(H * W)
        even_odd = torch.ifft(imagefft * filters.unsqueeze(-1),
                              2).view(N, orientations, scales, H, W, 2)

    # Amplitude of even & odd filter response. An = sqrt(real^2 + imag^2)
    an = torch.sqrt(torch.sum(even_odd**2, dim=-1))

    # Take filter at scale 0 and sum spatially
    # Record mean squared filter value at smallest scale.
    # This is used for noise estimation.
    em_n = (filters.view(1, orientations, scales, H,
                         W)[:, :, :1, ...]**2).sum(dim=[-2, -1], keepdim=True)

    # Sum of even filter convolution results.
    sum_e = even_odd[..., 0].sum(dim=2, keepdim=True)

    # Sum of odd filter convolution results.
    sum_o = even_odd[..., 1].sum(dim=2, keepdim=True)

    # Get weighted mean filter response vector, this gives the weighted mean phase angle.
    x_energy = torch.sqrt(sum_e**2 + sum_o**2) + EPS

    mean_e = sum_e / x_energy
    mean_o = sum_o / x_energy

    # Now calculate An(cos(phase_deviation) - | sin(phase_deviation)) | by
    # using dot and cross products between the weighted mean filter response
    # vector and the individual filter response vectors at each scale.
    # This quantity is phase congruency multiplied by An, which we call energy.

    # Extract even and odd convolution results.
    even = even_odd[..., 0]
    odd = even_odd[..., 1]

    energy = (even * mean_e + odd * mean_o -
              torch.abs(even * mean_o - odd * mean_e)).sum(dim=2, keepdim=True)

    # Compensate for noise
    # We estimate the noise power from the energy squared response at the
    # smallest scale.  If the noise is Gaussian the energy squared will have a
    # Chi-squared 2DOF pdf.  We calculate the median energy squared response
    # as this is a robust statistic.  From this we estimate the mean.
    # The estimate of noise power is obtained by dividing the mean squared
    # energy value by the mean squared filter value

    abs_eo = torch.sqrt(torch.sum(even_odd[:, :, :1, ...]**2,
                                  dim=-1)).reshape(N, orientations, 1, 1,
                                                   H * W)
    median_e2n = torch.median(abs_eo**2, dim=-1, keepdim=True).values

    # For a Chi-squared 2DOF variable, mean = -median / ln(0.5).
    mean_e2n = -median_e2n / math.log(0.5)

    # Estimate of noise power.
    noise_power = mean_e2n / em_n

    # Now estimate the total energy^2 due to noise
    # Estimate for sum(An^2) + sum(Ai.*Aj.*(cphi.*cphj + sphi.*sphj))
    filters_ifft = filters_ifft.view(1, orientations, scales, H, W)

    sum_an2 = torch.sum(filters_ifft**2, dim=-3, keepdim=True)

    sum_ai_aj = torch.zeros(N, orientations, 1, H, W).to(x)
    for s in range(scales - 1):
        sum_ai_aj = sum_ai_aj + (filters_ifft[:, :, s:s + 1] *
                                 filters_ifft[:, :, s + 1:]).sum(dim=-3,
                                                                 keepdim=True)

    sum_an2 = torch.sum(sum_an2, dim=[-1, -2], keepdim=True)
    sum_ai_aj = torch.sum(sum_ai_aj, dim=[-1, -2], keepdim=True)

    noise_energy2 = 2 * noise_power * sum_an2 + 4 * noise_power * sum_ai_aj

    # Rayleigh parameter
    tau = torch.sqrt(noise_energy2 / 2)

    # Expected value of noise energy
    noise_energy = tau * math.sqrt(math.pi / 2)
    noise_energy_sigma = torch.sqrt((2 - math.pi / 2) * tau**2)

    # Noise threshold
    T = noise_energy + k * noise_energy_sigma

    # The estimated noise effect calculated above is only valid for the PC_1 measure.
    # The PC_2 measure does not lend itself readily to the same analysis.  However
    # empirically it seems that the noise effect is overestimated roughly by a factor
    # of 1.7 for the filter parameters used here.

    # Empirical rescaling of the estimated noise effect to suit the PC_2 phase congruency measure
    T = T / 1.7

    # Apply noise threshold
    energy = torch.max(energy - T, torch.zeros_like(T))

    # Sum over orientations (dim 1) and the kept scale dim (dim 2) -> (N, H, W).
    eps = torch.finfo(energy.dtype).eps
    energy_all = energy.sum(dim=[1, 2]) + eps
    an_all = an.sum(dim=[1, 2]) + eps
    result_pc = energy_all / an_all
    # Restore a channel dimension: (N, 1, H, W).
    return result_pc.unsqueeze(1)
Esempio n. 17
0
def make_tensor(
    *shape: Union[int, torch.Size, List[int], Tuple[int, ...]],
    dtype: torch.dtype,
    device: Union[str, torch.device],
    low: Optional[float] = None,
    high: Optional[float] = None,
    requires_grad: bool = False,
    noncontiguous: bool = False,
    exclude_zero: bool = False
) -> torch.Tensor:
    r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with
    values uniformly drawn from ``[low, high)``.

    If :attr:`low` or :attr:`high` are specified and are outside the range of the :attr:`dtype`'s representable
    finite values then they are clamped to the lowest or highest representable finite value, respectively.
    If ``None``, then the following table describes the default values for :attr:`low` and :attr:`high`,
    which depend on :attr:`dtype`.

    +---------------------------+------------+----------+
    | ``dtype``                 | ``low``    | ``high`` |
    +===========================+============+==========+
    | boolean type              | ``0``      | ``2``    |
    +---------------------------+------------+----------+
    | unsigned integral type    | ``0``      | ``10``   |
    +---------------------------+------------+----------+
    | signed integral types     | ``-9``     | ``10``   |
    +---------------------------+------------+----------+
    | floating types            | ``-9``     | ``9``    |
    +---------------------------+------------+----------+
    | complex types             | ``-9``     | ``9``    |
    +---------------------------+------------+----------+

    For integral dtypes, non-integer bounds are rounded up (``math.ceil``): ``low`` is inclusive, so
    rounding it down could produce values below it, and ``high`` is exclusive for ``torch.randint``.
    For boolean dtype :attr:`low` and :attr:`high` are ignored.

    Args:
        shape (Tuple[int, ...]): Single integer or a sequence of integers defining the shape of the output tensor.
        dtype (:class:`torch.dtype`): The data type of the returned tensor.
        device (Union[str, torch.device]): The device of the returned tensor.
        low (Optional[Number]): Sets the lower limit (inclusive) of the given range. If a number is provided it is
            clamped to the least representable finite value of the given dtype. When ``None`` (default),
            this value is determined based on the :attr:`dtype` (see the table above). Default: ``None``.
        high (Optional[Number]): Sets the upper limit (exclusive) of the given range. If a number is provided it is
            clamped to the greatest representable finite value of the given dtype. When ``None`` (default) this value
            is determined based on the :attr:`dtype` (see the table above). Default: ``None``.
        requires_grad (Optional[bool]): If autograd should record operations on the returned tensor. Default: ``False``.
        noncontiguous (Optional[bool]): If `True`, the returned tensor will be noncontiguous. This argument is
            ignored if the constructed tensor has fewer than two elements.
        exclude_zero (Optional[bool]): If ``True`` then zeros are replaced with the dtype's small positive value
            depending on the :attr:`dtype`. For bool and integer types zero is replaced with one. For floating
            point types it is replaced with the dtype's smallest positive normal number (the "tiny" value of the
            :attr:`dtype`'s :func:`~torch.finfo` object), and for complex types it is replaced with a complex number
            whose real and imaginary parts are both the smallest positive normal number representable by the complex
            type. Default ``False``.

    Raises:
        ValueError: if ``requires_grad=True`` is passed for integral `dtype`
        ValueError: If ``low > high``.
        ValueError: If either :attr:`low` or :attr:`high` is ``nan``.
        TypeError: If :attr:`dtype` isn't supported by this function.

    Examples:
        >>> from torch.testing import make_tensor
        >>> # Creates a float tensor with values in [-1, 1)
        >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
        tensor([ 0.1205, 0.2282, -0.6380])
        >>> # Creates a bool tensor on CUDA
        >>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
        tensor([[False, False],
                [False, True]], device='cuda:0')
    """
    # Supported dtype families; shared by the range helper and the dispatch below.
    _integral_types = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
    _floating_types = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
    _complex_types = [torch.complex32, torch.complex64, torch.complex128]

    def _modify_low_high(low, high, lowest, highest, default_low, default_high, dtype):
        """Resolve defaults, validate, and clamp the user's ``low``/``high``.

        Raises ValueError when either bound is NaN or ``low > high``. Both
        bounds are clamped into ``[lowest, highest]``. For integral dtypes
        both are rounded up to ints (see the outer docstring).
        """
        def clamp(a, l, h):
            return min(max(a, l), h)

        low = low if low is not None else default_low
        high = high if high is not None else default_high

        # Checks for error cases (NaN is the only value not equal to itself).
        if low != low or high != high:
            raise ValueError("make_tensor: one of low or high was NaN!")
        if low > high:
            raise ValueError("make_tensor: low must be weakly less than high!")

        low = clamp(low, lowest, highest)
        high = clamp(high, lowest, highest)

        if dtype in _integral_types:
            # ceil(low): low is inclusive, so flooring could emit values < low.
            # ceil(high): high is exclusive for torch.randint.
            return math.ceil(low), math.ceil(high)

        return low, high

    # Accept both make_tensor(2, 3, ...) and make_tensor((2, 3), ...).
    if len(shape) == 1 and isinstance(shape[0], collections.abc.Sequence):
        shape = shape[0]  # type: ignore[assignment]
    shape = cast(Tuple[int, ...], tuple(shape))

    if requires_grad and dtype not in _floating_types and dtype not in _complex_types:
        raise ValueError("make_tensor: requires_grad must be False for integral dtype")

    if dtype is torch.bool:
        result = torch.randint(0, 2, shape, device=device, dtype=dtype)  # type: ignore[call-overload]
    elif dtype in _integral_types:
        iinfo = torch.iinfo(dtype)
        # Unsigned types default to a non-negative range.
        default_low = 0 if dtype is torch.uint8 else -9
        low, high = _modify_low_high(low, high, iinfo.min, iinfo.max, default_low, 10, dtype)
        result = torch.randint(low, high, shape, device=device, dtype=dtype)  # type: ignore[call-overload]
    elif dtype in _floating_types:
        ranges_floats = (torch.finfo(dtype).min, torch.finfo(dtype).max)
        low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
        rand_val = torch.rand(shape, device=device, dtype=dtype)
        # Linear interpolation between low and high with weights in [0, 1).
        result = high * rand_val + low * (1 - rand_val)
    elif dtype in _complex_types:
        float_dtype = complex_to_corresponding_float_type_map[dtype]
        ranges_floats = (torch.finfo(float_dtype).min, torch.finfo(float_dtype).max)
        low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
        real_rand_val = torch.rand(shape, device=device, dtype=float_dtype)
        imag_rand_val = torch.rand(shape, device=device, dtype=float_dtype)
        real = high * real_rand_val + low * (1 - real_rand_val)
        imag = high * imag_rand_val + low * (1 - imag_rand_val)
        result = torch.complex(real, imag)
    else:
        raise TypeError(f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()."
                        " To request support, file an issue at: https://github.com/pytorch/pytorch/issues")

    if noncontiguous and result.numel() > 1:
        # Duplicate along the last dim and take every other element: same
        # values, but a stride of 2 in the last dimension.
        result = torch.repeat_interleave(result, 2, dim=-1)
        result = result[..., ::2]

    if exclude_zero:
        if dtype in _integral_types or dtype is torch.bool:
            replace_with = torch.tensor(1, device=device, dtype=dtype)
        elif dtype in _floating_types:
            replace_with = torch.tensor(torch.finfo(dtype).tiny, device=device, dtype=dtype)
        else:  # dtype in _complex_types:
            float_dtype = complex_to_corresponding_float_type_map[dtype]
            float_eps = torch.tensor(torch.finfo(float_dtype).tiny, device=device, dtype=float_dtype)
            replace_with = torch.complex(float_eps, float_eps)
        result[result == 0] = replace_with

    if dtype in _floating_types + _complex_types:
        result.requires_grad = requires_grad

    return result
Esempio n. 18
0
 def get_loss(self, y_pred, y_true, *args, **kwargs):
     """Compute the loss, converting ``y_pred`` to logs first for NLLLoss.

     ``torch.nn.NLLLoss`` expects log-probabilities. When ``self.criterion_``
     is an NLLLoss, ``y_pred`` is replaced by ``log(y_pred + eps)``, with
     ``eps`` the machine epsilon of its dtype so a zero entry does not become
     ``-inf``. Presumably ``y_pred`` holds probabilities here; confirm against
     the module's output. All arguments are then forwarded to the parent
     ``get_loss``.
     """
     if not isinstance(self.criterion_, torch.nn.NLLLoss):
         return super().get_loss(y_pred, y_true, *args, **kwargs)
     shift = torch.finfo(y_pred.dtype).eps
     log_pred = (y_pred + shift).log()
     return super().get_loss(log_pred, y_true, *args, **kwargs)
Esempio n. 19
0
def get_rnnt_logprobs_smoothed(
    lm: Tensor,
    am: Tensor,
    symbols: Tensor,
    termination_symbol: int,
    lm_only_scale: float = 0.1,
    am_only_scale: float = 0.1,
    boundary: Optional[Tensor] = None,
    modified: bool = False,
) -> Tuple[Tensor, Tensor]:
    """
    Reduces RNN-T problem (the simple case, where joiner network is just
    addition), to a compact, standard form that can then be given
    (with boundaries) to mutual_information_recursion().
    This version allows you to make the loss-function one of the form::

          lm_only_scale * lm_probs +
          am_only_scale * am_probs +
          (1-lm_only_scale-am_only_scale) * combined_probs

    where lm_probs and am_probs are the probabilities given the lm and acoustic
    model independently.

    This function is called from
    :func:`rnnt_loss_smoothed`, but may be useful for other purposes.

    Args:
      lm:
        Language model part of un-normalized logprobs of symbols, to be added to
        acoustic model part before normalizing.  Of shape::

           [B][S+1][C]

        where B is the batch size, S is the maximum sequence length of
        the symbol sequence, possibly including the EOS symbol; and
        C is size of the symbol vocabulary, including the termination/next-frame
        symbol.
        Conceptually, lm[b][s] is a vector of length [C] representing the
        "language model" part of the un-normalized logprobs of symbols,
        given all symbols *earlier than* s in the sequence.  The reason
        we still need this for position S is that we may still be emitting
        the termination/next-frame symbol at this point.
      am:
        Acoustic-model part of un-normalized logprobs of symbols, to be added
        to language-model part before normalizing.  Of shape::

           [B][T][C]

        where B is the batch size, T is the maximum sequence length of
        the acoustic sequences (in frames); and C is size of the symbol
        vocabulary, including the termination/next-frame symbol.  It reflects
        the "acoustic" part of the probability of any given symbol appearing
        next on this frame.
      symbols:
        A LongTensor of shape [B][S], containing the symbols at each position
        of the sequence.
      termination_symbol:
        The identity of the termination symbol, must be in {0..C-1}
      lm_only_scale:
        the scale on the "LM-only" part of the loss.
      am_only_scale:
        the scale on the "AM-only" part of the loss, for which we use
        an "averaged" LM (averaged over all histories, so effectively unigram).
      boundary:
        an optional LongTensor of shape [B, 4] with elements interpreted as
        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
        [0, 0, S, T]
        if boundary is not supplied.
        Most likely you will want begin_symbol and begin_frame to be zero.
        Only used (via fix_for_boundary) when ``modified`` is False.
      modified: if True, each time a real symbol is consumed a frame will
        also be consumed, so at most 1 symbol can appear per frame.
    Returns:
        (px, py) (the names are quite arbitrary).
           px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
           py: logprobs, of shape [B][S+1][T]

        in the recursion::

          p[b,0,0] = 0.0
          if !modified:
             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                p[b,s,t-1] + py[b,s,t-1])
          if modified:
             p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                                p[b,s,t-1] + py[b,s,t-1])
          .. where p[b][s][t] is the "joint score" of the pair of subsequences
          of length s and t respectively.  px[b][s][t] represents the
          probability of extending the subsequences of length (s,t) by one in
          the s direction, given the particular symbol, and py[b][s][t]
          represents the probability of extending the subsequences of length
          (s,t) by one in the t direction,
          i.e. of emitting the termination/next-frame symbol.

          When not modified, px[:,:,T] equals -infinity, meaning on the
          "one-past-the-last" frame we cannot emit any symbols.  This is simply
          a way of incorporating the probability of the termination symbol on
          the last frame.
    """
    assert lm.ndim == 3
    assert am.ndim == 3
    assert lm.shape[0] == am.shape[0]
    assert lm.shape[2] == am.shape[2]
    (B, T, C) = am.shape
    S = lm.shape[1] - 1
    assert symbols.shape == (B, S)

    # Caution: some parts of this code are a little less clear than they could
    # be due to optimizations.  In particular it may not be totally obvious that
    # all of the logprobs here are properly normalized.  We test that
    # this code is invariant to adding constants in the appropriate ways.

    # subtracting am_max and lm_max is to ensure the probs are in a good range
    # to do exp() without causing underflow or overflow.
    am_max, _ = torch.max(am, dim=2, keepdim=True)  # am_max: [B][T][1]
    lm_max, _ = torch.max(lm, dim=2, keepdim=True)  # lm_max: [B][S+1][1]
    am_probs = (am - am_max).exp()  # [B][T][C]
    lm_probs = (lm - lm_max).exp()  # [B][S+1][C]
    # normalizers: [B][S+1][T]
    normalizers = (torch.matmul(lm_probs, am_probs.transpose(1, 2)) +
                   torch.finfo(lm_probs.dtype).tiny).log()

    # normalizer per frame, if we take only the LM probs by themselves
    lmonly_normalizers = lm_probs.sum(
        dim=2, keepdim=True)  # lmonly_normalizers: [B][S+1][1]
    unigram_lm = (
        torch.mean(lm_probs / lmonly_normalizers, dim=(0, 1), keepdim=True) +
        torch.finfo(lm_probs.dtype).tiny)  # [1][1][C]
    amonly_normalizers = (torch.mv(am_probs.reshape(
        -1, C), unigram_lm.reshape(C)).reshape(B, T, 1).log() + am_max
                          )  # [B][T][1]
    amonly_normalizers = amonly_normalizers.transpose(1, 2)  # [B][1][T]
    unigram_lm = unigram_lm.log()
    lmonly_normalizers = (
        lmonly_normalizers.log() + lm_max
    )  # [B][S+1][1], log-normalizer, used for LM-only part of prob.

    # add lm_max and am_max to normalizers, to make it as if we had not
    # subtracted am_max and lm_max above.
    normalizers = normalizers + lm_max + am_max.transpose(1, 2)  # [B][S+1][T]

    # px is the probs of the actual symbols (not yet normalized)..
    px_am = torch.gather(
        am.unsqueeze(1).expand(B, S, T, C),
        dim=3,
        index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
    ).squeeze(-1)  # [B][S][T]

    if not modified:
        px_am = torch.cat(
            (
                px_am,
                torch.full(
                    (B, S, 1),
                    float("-inf"),
                    device=px_am.device,
                    dtype=px_am.dtype,
                ),
            ),
            dim=2,
        )  # now: [B][S][T+1], index [:,:,T] has -inf..

    px_lm = torch.gather(lm[:, :S], dim=2,
                         index=symbols.unsqueeze(-1))  # [B][S][1]
    px_lm_unigram = torch.gather(unigram_lm.expand(B, S, C),
                                 dim=2,
                                 index=symbols.unsqueeze(-1))  # [B][S][1]

    px = px_am + px_lm  # [B][S][T+1] if not modified, [B][S][T] if modified
    px[:, :, :T] -= normalizers[:, :S, :]  # px: [B][S][T+1] or [B][S][T]

    px_amonly = (px_am + px_lm_unigram
                 )  # [B][S][T+1] if !modified; [B][S][T] if modified.
    px_amonly[:, :, :T] -= amonly_normalizers
    px_lmonly = px_lm - lmonly_normalizers[:, :S, :]

    # py is the probs of termination symbols, of shape [B][S+1][T]
    py_am = am[:, :, termination_symbol].unsqueeze(1)  # [B][1][T]
    py_lm = lm[:, :, termination_symbol].unsqueeze(2)  # [B][S+1][1]
    py = py_am + py_lm - normalizers

    py_lm_unigram = unigram_lm[0][0][
        termination_symbol]  # scalar, normalized..
    py_amonly = py_am + py_lm_unigram - amonly_normalizers  # [B][S+1][T]
    py_lmonly = py_lm - lmonly_normalizers  # [B][S+1][T]

    combined_scale = 1.0 - lm_only_scale - am_only_scale

    # We need to avoid exact zeros in the scales because otherwise multiplying
    # -inf by zero generates nan.  This includes combined_scale: when
    # lm_only_scale + am_only_scale == 1 it is exactly 0, and px still holds
    # -inf at px[:, :, T] in the non-modified case.
    if combined_scale == 0.0:
        combined_scale = 1.0e-20
    if lm_only_scale == 0.0:
        lm_only_scale = 1.0e-20
    if am_only_scale == 0.0:
        am_only_scale = 1.0e-20

    px_interp = (px * combined_scale + px_lmonly * lm_only_scale +
                 px_amonly * am_only_scale)
    py_interp = (py * combined_scale + py_lmonly * lm_only_scale +
                 py_amonly * am_only_scale)

    if not modified:
        px_interp = fix_for_boundary(px_interp, boundary)

    return (px_interp, py_interp)
Esempio n. 20
0
"""Beamformer module."""
from typing import List
from typing import Optional
from typing import Union

import numpy as np
import torch
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor

# Machine epsilon of float64; complex_norm adds it under its square root.
EPS = torch.finfo(torch.float64).eps


def complex_norm(c: ComplexTensor) -> torch.Tensor:
    """Return the L2 norm of ``c`` over its last dimension, keeping that dim.

    ``EPS`` is added inside the square root, so an all-zero vector never
    passes an exact zero to ``sqrt``.
    """
    squared_magnitude = c.real**2 + c.imag**2
    return torch.sqrt(squared_magnitude.sum(dim=-1, keepdim=True) + EPS)


def get_rtf(
    psd_speech: ComplexTensor,
    psd_noise: ComplexTensor,
    reference_vector: Union[int, torch.Tensor, None] = None,
    iterations: int = 3,
    use_torch_solver: bool = True,
) -> ComplexTensor:
    """Calculate the relative transfer function (RTF) using the power method.

    Algorithm:
        1) rtf = reference_vector
        2) for i in range(iterations):
             rtf = (psd_noise^-1 @ psd_speech) @ rtf
             rtf = rtf / ||rtf||_2  # this normalization can be skipped
def neg_inf(dtype):
    """Return the most negative finite value representable by ``dtype``.

    Despite the name, the result is finite (``torch.finfo(dtype).min``), not
    ``float('-inf')``.
    """
    info = torch.finfo(dtype)
    return info.min
Esempio n. 22
0
def _clipped_sigmoid(x):
    finfo = torch.finfo(x.dtype)
    return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)
Esempio n. 23
0
def max_neg_value(tensor):
    """Return the negated largest finite value of ``tensor``'s floating dtype."""
    largest = torch.finfo(tensor.dtype).max
    return -largest
Esempio n. 24
0
 def _inverse(self, y):
     finfo = torch.finfo(y.dtype)
     y = y.clamp(min=finfo.tiny, max=1. - finfo.eps)
     return y.log() - (-y).log1p()
Esempio n. 25
0
def update_model(
    optimizer: optim.Optimizer,
    scaler: amp.grad_scaler.GradScaler,
    buffer: Buffer,
    state: TSP2OPTState,
    done: bool,
    epoch: int,
    count: int,
    learn_count: int,
    global_step: int,
    logger: SummaryWriter,
    args,
):
    """Run one actor-critic update from the transitions stored in ``buffer``.

    Builds discounted (optionally standardized) returns, then three losses:
    a policy-gradient loss weighted by detached advantages, a value-regression
    loss, and an entropy bonus whose weight decays as ``0.9 ** (epoch + 1)``.
    The policy loss and the remaining ``v_loss - e_loss`` are backpropagated
    separately through ``scaler`` before a single optimizer step. The buffer
    is then cleared and the statistics are passed to ``log_values``.

    Args:
        optimizer: optimizer stepped through ``scaler``.
        scaler: AMP gradient scaler used for both backward passes and the step.
        buffer: reads ``rewards``, ``values``, ``log_probs`` and ``actions``
            lists; ``clear_buffer()`` is called afterwards.
        state: only ``state.best_tour_len`` is read (for logging).
        done: forwarded to ``log_values``.
        epoch: controls the entropy-weight decay; also logged.
        count: forwarded to ``discounted_return``.
        learn_count: update counter; logged, then incremented.
        global_step: forwarded to ``log_values``.
        logger: forwarded to ``log_values``.
        args: reads ``gamma``, ``no_norm_return``, ``value_beta`` and
            ``entropy_beta``; also forwarded to ``log_values``.

    Returns:
        ``learn_count + 1``.
    """

    rewards = torch.stack(buffer.rewards, dim=0)  # [horizon, batch_size, 1]
    returns = discounted_return(rewards, args.gamma,
                                count)  # [horizon, batch_size, 1]
    if not args.no_norm_return:
        # Standardize returns across the whole rollout.
        # NOTE(review): torch's default std is unbiased, so a single-element
        # ``returns`` gives NaN here and eps does not help.
        r_mean = returns.mean()
        r_std = returns.std()
        eps = torch.finfo(torch.float).eps  # small number to avoid div/0
        returns = (returns - r_mean) / (r_std + eps)
    values = torch.stack(buffer.values, dim=0)  # [horizon, batch_size, 1]
    # Detached so the policy loss does not backpropagate into the critic.
    advantages = (returns - values).detach()  # [horizon, batch_size, 1]

    logps = torch.stack(buffer.log_probs,
                        dim=0)  # [horizon, batch_size, 2, graph_size]
    actions = torch.stack(buffer.actions, dim=0)  # [horizon, batch_size, 2, 1]
    # Log-probability of each chosen action, averaged over the 2 sub-actions.
    log_likelihood = logps.gather(-1, actions).squeeze(
        -1)  # [horizon, batch_size, 2]
    log_likelihood = log_likelihood.mean(2).unsqueeze(
        2)  # [horizon, batch_size, 1]

    entropies = log_p_to_entropy(logps).mean(2).unsqueeze(
        2)  # [horizon, batch_size, 1]

    p_loss = (-log_likelihood * advantages).mean()
    v_loss = args.value_beta * (returns - values).pow(2).mean()
    e_loss = (0.9**(epoch + 1)) * args.entropy_beta * entropies.sum(0).mean()
    # Entropy is a bonus, hence subtracted.
    r_loss = -e_loss + v_loss
    # Only used for logging below; not backpropagated as a whole.
    loss = p_loss + r_loss

    optimizer.zero_grad()
    # Graph is retained because r_loss is backpropagated through it next.
    scaler.scale(p_loss).backward(retain_graph=True)
    # NOTE(review): with unscale_ commented out, clip_grad_norms sees the
    # scaler-scaled gradients of p_loss only (r_loss gradients are added
    # afterwards), and the max_grad_norm argument is commented out too. The
    # logged grad_norms are therefore in scaled units; confirm this is intended.
    # scaler.unscale_(optimizer)
    grad_norms = clip_grad_norms(
        optimizer.param_groups)  #, args.max_grad_norm)
    scaler.scale(r_loss).backward(retain_graph=False)
    scaler.step(optimizer)
    scaler.update()

    buffer.clear_buffer()
    log_values(
        cost=state.best_tour_len,
        grad_norms=grad_norms,
        done=done,
        epoch=epoch,
        global_step=global_step,
        learn_count=learn_count,
        p_loss=p_loss,
        v_loss=v_loss,
        e_loss=e_loss,
        loss=loss,
        returns=returns.mean(),
        value=values.mean(),
        entropy=entropies.detach().mean(),
        logger=logger,
        args=args,
    )

    learn_count += 1

    return learn_count
Esempio n. 26
0
File: const.py Progetto: yt752/aps
# Copyright 2020 Jian Wu
# License: Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
Give pre-defined values
"""

import math
import numpy as np
import torch as th

# Special ids and token strings.
IGNORE_ID = -1  # sentinel id; presumably skipped by losses (confirm)
UNK_TOKEN = "<unk>"
BLK_TOKEN = "<b>"

# Numeric constants.
MATH_PI = math.pi
NEG_INF = th.finfo(th.float).min  # th.float is float32; finite, not -inf
EPSILON = np.finfo(np.single).eps  # float32 machine epsilon (NumPy scalar)
MAX_INT16 = np.iinfo(np.int16).max  # 32767
Esempio n. 27
0
 def clamp_finite(self):
     """Return a new ``Tensor`` whose data is clamped into the finite range.

     Entries of ``self.data`` (including +/-inf) are clamped into
     ``[finfo.min, finfo.max]`` of its floating dtype. ``self.inputs`` and
     ``self.dtype`` are passed through unchanged.
     """
     limits = torch.finfo(self.data.dtype)
     finite_data = torch.clamp(self.data, min=limits.min, max=limits.max)
     return Tensor(finite_data, self.inputs, self.dtype)
Esempio n. 28
0
    def forward(
        self,
        query: torch.Tensor,
        glimpse_K: torch.Tensor,
        glimpse_V: torch.Tensor,
        logit_K: torch.Tensor,
        attn_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        r"""Multi-head glimpse over the source, then log-probabilities over it.

        The query attends to ``glimpse_K``/``glimpse_V`` with ``self.num_heads``
        heads. The concatenated heads are projected by ``self.glimpse_proj``
        and scored against ``logit_K`` with a scaled dot product. When
        ``self.tanh_clipping > 0`` the scores become
        ``tanh(scores) * self.tanh_clipping``. A log-softmax over the source
        dimension gives the output.

        Shape:
            Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
            the embedding dimension.
            - glimpse_K: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
            the embedding dimension.
            - glimpse_V: :math:`(S, N, E)`
            - logit_K: :math:`(N, S, E)`
            - attn_mask: optional ``torch.bool`` mask of shape :math:`(N, L, S)`. Positions with ``True``
            are not allowed to attend. They are filled with the dtype's most negative finite value both in
            the glimpse attention (broadcast over heads) and in the output logits. A non-bool or non-3D
            mask fails an assertion; a 3D mask of any other size raises RuntimeError.

            Outputs:
            - log_prob: :math:`(N, L, S)` where N is the batch size,
            L is the target sequence length, S is the source sequence length.
            Dimension 1 is squeezed, so the result is :math:`(N, S)` when L == 1.

        Note:
            If the log-probabilities contain NaN, the intermediate tensors are
            saved to ``nan-tensor.pt`` in the current working directory and an
            AssertionError is raised.
        """
        num_heads = self.num_heads
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        src_len, _, embed_dim = glimpse_K.size()
        assert embed_dim == self.embed_dim
        assert list(glimpse_K.size()) == list(glimpse_V.size())
        assert list(logit_K.size()) == [
            bsz, src_len, embed_dim
        ], f"{logit_K.size()} - {[bsz, src_len, embed_dim]}"
        head_dim = embed_dim // num_heads
        assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        if attn_mask is not None:
            assert attn_mask.dtype == torch.bool, "Only bool types are supported for attn_mask, not {}".format(
                attn_mask.dtype)
            assert attn_mask.dim(
            ) == 3, "attn_mask's dimension {} is not supported".format(
                attn_mask.dim())
            if list(attn_mask.size()) == [
                    bsz,
                    tgt_len,
                    src_len,
            ]:
                # Same mask for every head: (num_heads, N, L, S).
                heads_mask = attn_mask.expand(num_heads, *attn_mask.size())
            else:
                raise RuntimeError(
                    "The size of the 3D attn_mask is not correct.")

        # (n_heads, batch_size, target_len, head_dim)
        glimpse_Q = query.view(tgt_len, bsz, num_heads,
                               head_dim).permute(2, 1, 0, 3)
        # (n_heads, batch_size, source_len, head_dim)
        glimpse_K = glimpse_K.view(src_len, bsz, num_heads,
                                   head_dim).permute(2, 1, 0, 3)
        glimpse_V = glimpse_V.view(src_len, bsz, num_heads,
                                   head_dim).permute(2, 1, 0, 3)

        compatibility = torch.matmul(glimpse_Q, glimpse_K.transpose(
            -2, -1)) / math.sqrt(glimpse_Q.size(-1))
        assert list(compatibility.size()) == [num_heads, bsz, tgt_len, src_len]

        if attn_mask is not None:
            # Finite fill (not -inf) so a fully-masked row cannot produce NaN.
            compatibility.masked_fill_(heads_mask,
                                       torch.finfo(compatibility.dtype).min)

        heads = torch.matmul(torch.softmax(compatibility, dim=-1), glimpse_V)
        assert list(heads.size()) == [num_heads, bsz, tgt_len, head_dim]

        # Concatenate heads back to (N, L, E) and project.
        glimpse = self.glimpse_proj(
            heads.permute(1, 2, 0,
                          3).contiguous().view(bsz, tgt_len, embed_dim))

        final_Q = glimpse
        assert list(final_Q.size()) == [bsz, tgt_len, embed_dim]
        logits = torch.matmul(final_Q, logit_K.transpose(-2, -1)) / math.sqrt(
            final_Q.size(-1))
        assert list(logits.size()) == [bsz, tgt_len, src_len]

        if self.tanh_clipping > 0:
            logits = torch.tanh(logits) * self.tanh_clipping
        if attn_mask is not None:
            logits.masked_fill_(attn_mask, torch.finfo(logits.dtype).min)

        log_prob = F.log_softmax(logits, dim=-1)
        if torch.isnan(log_prob).any():
            # Debug dump of the intermediates before failing.
            torch.save(
                {
                    "glimpse_Q": glimpse_Q,
                    "glimpse_K": glimpse_K,
                    "glimpse_V": glimpse_V,
                    "compatibility": compatibility,
                    "heads": heads,
                    "final_Q": final_Q,
                    "logits": logits,
                }, "nan-tensor.pt")
            assert not torch.isnan(log_prob).any()

        return log_prob.squeeze(1)
Esempio n. 29
0
def _safesub(x, y):
    try:
        return x + -y.clamp(max=torch.finfo(y.dtype).max)
    except TypeError:
        return x + -y.clamp(max=torch.iinfo(y.dtype).max)
Esempio n. 30
0
    def forward(self, pyolos, targets, imgs_ts=None):
        '''
        Compute the YOLOv1-style detection loss for one batch.

        :param pyolos: raw network output of shape (batch, c, h, w); channels
            are read as cls-NUM_CLASSES followed by txywh-4
            (e.g. torch.Size([32, 7, 13, 13]) for 3 classes).
        :param targets: one dict per image with 'boxes' (ltrb) and 'labels'.
        :param imgs_ts: optional per-image tensors forwarded to
            ``fmatch4yolov1`` as ``img_ts``; ``None`` is forwarded when omitted.
        :return: ``(loss_total, log_dict)`` where ``log_dict`` maps names to
            python floats for logging.
        '''
        cfg = self.cfg
        device = pyolos.device
        batch, c, h, w = pyolos.shape
        # (batch, c, h, w) -> (batch, h*w, c)
        pyolos = pyolos.view(batch, c, -1).permute(0, 2, 1)

        # Per-cell target layout: cls-NUM_CLASSES, txywh-4, weight-1, gltrb-4
        gdim = cfg.NUM_CLASSES + 4 + 1 + 4
        # Zero-initialised so that batch entries without a matching target
        # stay background instead of holding uninitialised memory.
        gyolos = torch.zeros((batch, h, w, gdim), device=device)

        for i, target in enumerate(targets):  # iterate over the batch
            gyolos[i] = fmatch4yolov1(
                gboxes_ltrb_b=target['boxes'],  # ltrb
                glabels_b=target['labels'],
                grid=h,  # NOTE(review): single grid size => square grid assumed
                gdim=gdim,
                device=device,
                img_ts=None if imgs_ts is None else imgs_ts[i],
                cfg=cfg,
                use_conf=False)

        # (batch, h, w, gdim) -> (batch, h*w, gdim)
        gyolos = gyolos.view(batch, -1, gdim)
        gcls = gyolos[:, :, 0:cfg.NUM_CLASSES]  # (batch, h*w, NUM_CLASSES)
        # A cell is positive if any of its class targets is > 0.
        mask_pos_2d = torch.any(gcls > 0, dim=-1)  # (batch, h*w)

        # Positive cells per image; clamped so images without objects do not
        # divide by zero.
        nums_pos = (mask_pos_2d.sum(-1).to(
            torch.float)).clamp(min=torch.finfo(torch.float16).eps)
        gyolos_pos = gyolos[mask_pos_2d]  # (n_pos, gdim)

        ''' ---------------- classification loss (focal loss) ---------------- '''
        pcls_sigmoid = pyolos[:, :, 0:cfg.NUM_CLASSES].sigmoid()
        l_pos, l_neg = focalloss(pcls_sigmoid,
                                 gcls,
                                 mask_pos=mask_pos_2d,
                                 is_debug=True,
                                 alpha=0.75)
        # Normalise by each image's positive count, average over the batch,
        # then apply the fixed weight of 30.
        l_cls_pos = (l_pos.sum(-1).sum(-1) / nums_pos).mean() * 30
        l_cls_neg = (l_neg.sum(-1).sum(-1) / nums_pos).mean() * 30

        ''' ---------------- regression loss (CIoU on decoded boxes) ---------------- '''
        if gyolos_pos.shape[0] > 0:
            # Decoding works on the full (batch, h*w, 4) tensor, so decode
            # first and select the positive cells afterwards.
            ptxywh = pyolos[:, :, cfg.NUM_CLASSES:cfg.NUM_CLASSES + 4]
            # NOTE(review): h is passed for both grid dimensions, which is only
            # correct for square grids — confirm boxes_decode4yolo1's signature
            # before switching the second argument to w.
            zltrb_pos = boxes_decode4yolo1(ptxywh, h, h, cfg)[mask_pos_2d]
            gltrb_pos = gyolos_pos[:, cfg.NUM_CLASSES + 4 + 1:cfg.NUM_CLASSES + 4 +
                                   1 + 4]
            iou_zg = bbox_iou4one_2d(zltrb_pos, gltrb_pos, is_ciou=True)
            l_reg = (1 - iou_zg).mean()
        else:
            # No positive cell in the whole batch: mean() over an empty tensor
            # would be NaN and poison loss_total.
            l_reg = pyolos.new_zeros(())

        ''' ---------------- total loss ----------------- '''
        loss_total = l_cls_pos + l_cls_neg + l_reg

        log_dict = {}
        log_dict['l_total'] = loss_total.item()
        log_dict['l_cls_pos'] = l_cls_pos.item()
        log_dict['l_cls_neg'] = l_cls_neg.item()
        log_dict['l_reg'] = l_reg.item()

        log_dict['p_max'] = pcls_sigmoid.max().item()
        log_dict['p_min'] = pcls_sigmoid.min().item()
        log_dict['p_mean'] = pcls_sigmoid.mean().item()
        return loss_total, log_dict