Example #1
 def normalize(self, x, ax):
     # Robust min-max normalization: use the NaN-aware 1st/99th percentiles
     # instead of min/max so outliers and missing values do not skew the range.
     q99 = torch.nanquantile(x, 0.99, dim=ax, keepdim=True)
     q01 = torch.nanquantile(x, 0.01, dim=ax, keepdim=True)
     den = q99 - q01
     den[den == 0] = 1  # avoid division by zero on constant slices
     x = (x - q01) / den
     x[x < 0] = 0  # clip the lower tail; values above 1 are left as-is
     return x
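A short sketch of what the two nanquantile calls above compute on concrete data (the tensor values here are illustrative): NaN entries are skipped when estimating the 1st and 99th percentiles, so the robust range is not distorted by missing values.

import torch

x = torch.tensor([[0.0, 1.0, 2.0, float('nan'), 100.0]])
q01 = torch.nanquantile(x, 0.01, dim=1, keepdim=True)  # ~0.03, the NaN is ignored
q99 = torch.nanquantile(x, 0.99, dim=1, keepdim=True)  # ~97.06
print((x - q01) / (q99 - q01))  # NaN stays NaN; the row lands roughly in [0, 1]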
Example #2
import numpy as np
import torch
from torch.nn.functional import pad, unfold  # assumption: the snippet's pad/unfold are the functional ops

def quantile_filter_inplace(img, ksize=3, quantile=.5):
    # Pad the (N, C, H, W) image with NaN so nanquantile ignores the border of each patch.
    if ksize % 2 != 0:
        pad_size = (int(ksize / 2),) * 4
    else:
        pad_size = (int(ksize / 2) - 1, int(ksize / 2)) * 2
    img_unf = unfold(pad(img, pad_size, value=np.nan), ksize)
    # flat_img is a view of img, so writing into it filters the image in place.
    flat_img = img.view(*img.size()[:2], img.size(2) * img.size(3))
    torch.nanquantile(img_unf, quantile, 1, keepdim=True, out=flat_img)
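A hedged usage sketch, assuming the imports noted above (torch.nn.functional's pad/unfold and NumPy for the NaN fill value): with quantile=0.5 the function acts as a NaN-aware 3x3 median filter and overwrites img through the flat_img view.

img = torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5)
img[0, 0, 2, 2] = float('nan')  # simulate a missing pixel
quantile_filter_inplace(img, ksize=3, quantile=0.5)
print(img)  # per-pixel medians of each 3x3 neighborhood; NaN padding and the missing pixel are ignored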
Example #3
 def reduction_ops(self):
     # Smoke-test a collection of PyTorch reduction ops on a random vector;
     # quantile and nanquantile take a tensor of probabilities as the second argument.
     a = torch.randn(4)
     b = torch.randn(4)
     return (
         torch.argmax(a),
         torch.argmin(a),
         torch.amax(a),
         torch.amin(a),
         torch.aminmax(a),
         torch.all(a),
         torch.any(a),
         torch.max(a),
         torch.min(a),
         torch.dist(a, b),
         torch.logsumexp(a, 0),
         torch.mean(a),
         torch.nanmean(a),
         torch.median(a),
         torch.nanmedian(a),
         torch.mode(a),
         torch.norm(a),
         torch.nansum(a),
         torch.prod(a),
         torch.quantile(a, torch.tensor([0.25, 0.5, 0.75])),
         torch.nanquantile(a, torch.tensor([0.25, 0.5, 0.75])),
         torch.std(a),
         torch.std_mean(a),
         torch.sum(a),
         torch.unique(a),
         torch.unique_consecutive(a),
         torch.var(a),
         torch.var_mean(a),
         torch.count_nonzero(a),
     )
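A minimal demonstration of why nanquantile appears next to quantile in the list above: with a NaN in the input, torch.quantile propagates the NaN while torch.nanquantile computes the quantile over the remaining values.

import torch

t = torch.tensor([float('nan'), 1.0, 2.0])
print(torch.quantile(t, 0.5))     # tensor(nan)
print(torch.nanquantile(t, 0.5))  # tensor(1.5000)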
Example #4
from typing import Dict

import torch


def generate_overlap_vad_seq_per_tensor(
    frame: torch.Tensor, per_args: Dict[str, float], smoothing_method: str
) -> torch.Tensor:
    """
    Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate prediction with overlapping input window/segments
    See description in generate_overlap_vad_seq.
    Use this for single instance pipeline. 
    """
    # This function will be refactor for vectorization but this is okay for now

    overlap = per_args['overlap']
    window_length_in_sec = per_args['window_length_in_sec']
    shift_length_in_sec = per_args['shift_length_in_sec']
    frame_len = per_args.get('frame_len', 0.01)

    shift = int(shift_length_in_sec / frame_len)  # number of units of shift
    seg = int((window_length_in_sec / frame_len + 1))  # number of units of each window/segment

    jump_on_target = int(seg * (1 - overlap))  # jump on target generated sequence
    jump_on_frame = int(jump_on_target / shift)  # jump on input frame sequence

    if jump_on_frame < 1:
        raise ValueError(
            "Note that we jump over the frame sequence to generate overlapping input segments.\n"
            f"Your input makes jump_on_frame={jump_on_frame} < 1, which is invalid because it cannot jump and would get stuck.\n"
            "Please try different window_length_in_sec, shift_length_in_sec and overlap choices.\n"
            "jump_on_target = int(seg * (1 - overlap))\n"
            "jump_on_frame  = int(jump_on_target / shift)"
        )

    target_len = int(len(frame) * shift)

    if smoothing_method == 'mean':
        preds = torch.zeros(target_len)
        pred_count = torch.zeros(target_len)

        for i, og_pred in enumerate(frame):
            if i % jump_on_frame != 0:
                continue
            start = i * shift
            end = start + seg
            preds[start:end] = preds[start:end] + og_pred
            pred_count[start:end] = pred_count[start:end] + 1

        preds = preds / pred_count
        last_non_zero_pred = preds[pred_count != 0][-1]
        preds[pred_count == 0] = last_non_zero_pred

    elif smoothing_method == 'median':
        preds = [torch.empty(0) for _ in range(target_len)]
        for i, og_pred in enumerate(frame):
            if i % jump_on_frame != 0:
                continue

            start = i * shift
            end = start + seg
            for j in range(start, end):
                if j <= target_len - 1:
                    preds[j] = torch.cat((preds[j], og_pred.unsqueeze(0)), 0)

        preds = torch.stack([torch.nanquantile(l, q=0.5) for l in preds])
        nan_idx = torch.isnan(preds)
        last_non_nan_pred = preds[~nan_idx][-1]
        preds[nan_idx] = last_non_nan_pred

    else:
        raise ValueError("smoothing_method should be either mean or median")

    return preds
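A sketch of how this function might be called; the window, shift and overlap values below are illustrative, not defaults from the original code.

import torch

frame = torch.rand(50)  # hypothetical frame-level speech probabilities at a 10 ms shift
per_args = {'overlap': 0.5, 'window_length_in_sec': 0.2, 'shift_length_in_sec': 0.01}
smoothed = generate_overlap_vad_seq_per_tensor(frame, per_args, smoothing_method='median')
print(smoothed.shape)  # torch.Size([50]); the median branch reduces each position with torch.nanquantile(q=0.5)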