コード例 #1
0
ファイル: video_timestamps.py プロジェクト: mmazab/LifeQA
def peaks(boxes: np.ndarray, samples_per_box: int) -> Mapping[int, Iterable[int]]:
    """Select the strongest cells of every box and group them by frequency index.

    ``boxes`` must be 4-dimensional and is indexed as
    ``(vertical bin, row within box, horizontal bin, column within box)``;
    that is the layout implied by the transpose and the shape unpacking below.
    NOTE(review): this matches ``spectrogram.reshape(V, W, H, Bh)`` of a
    ``(samples, frequencies)`` array -- presumably how callers build it;
    confirm against the caller.

    For each box, the ``samples_per_box`` largest values are picked with
    ``np.argpartition``; the relative order of the picks inside a single box
    (and the choice among ties) is whatever ``argpartition`` yields.

    Args:
        boxes: 4-D array of box values, laid out as described above.
        samples_per_box: Number of peaks to keep per box. ``0`` yields an
            empty mapping.

    Returns:
        A ``defaultdict(list)`` mapping each global column ("frequency") index
        to the list of global row ("sample") indices of the peaks found in
        that column. Boxes are visited horizontal bin first, then vertical
        bin, so each list is appended to in that order.

    Raises:
        ValueError: If ``samples_per_box`` is negative or larger than the
            number of cells in a box, or if ``boxes`` is not 4-dimensional.
    """
    if samples_per_box < 0:
        raise ValueError("samples_per_box must be non-negative, got {}".format(samples_per_box))

    # Bring the two bin axes to the front so every box is a trailing 2-D slab.
    boxes = boxes.transpose((2, 0, 1, 3))
    horiz_bin_count, vert_bin_count, box_width, box_height = boxes.shape

    cells_per_box = box_width * box_height
    if samples_per_box > cells_per_box:
        raise ValueError(
            "samples_per_box ({}) exceeds the number of cells per box ({})".format(
                samples_per_box, cells_per_box
            )
        )

    freqs_dict = defaultdict(list)
    if samples_per_box == 0:
        # ``-0 == 0``: without this guard the ``[-0:]`` slice below would keep
        # every cell of every box instead of none.
        return freqs_dict

    # Flatten each box; a flat index encodes (row, col) as row * box_height + col.
    # The explicit length (rather than -1) also works for zero-sized inputs.
    flat_boxes = boxes.reshape(horiz_bin_count, vert_bin_count, cells_per_box)

    # After partitioning, the last ``samples_per_box`` positions hold the
    # indices of the largest values of each box (in no particular order).
    indices_peaks = flat_boxes.argpartition(-samples_per_box)[..., -samples_per_box:]

    # Split the flat in-box index back into its row and column parts.
    rows_in_box, cols_in_box = np.divmod(indices_peaks, box_height)

    # Turn in-box coordinates into global ones by adding each box's offset.
    indices_freqs = np.arange(horiz_bin_count)[:, np.newaxis, np.newaxis] * box_height + cols_in_box
    indices_samples = np.arange(vert_bin_count)[np.newaxis, :, np.newaxis] * box_width + rows_in_box

    # ``tolist`` converts to plain Python ints in one pass, which is much
    # cheaper than calling ``.item()`` on every element.
    for freq_idx, sample_idx in zip(indices_freqs.ravel().tolist(), indices_samples.ravel().tolist()):
        freqs_dict[freq_idx].append(sample_idx)
    return freqs_dict