def with_alignment_from_ctm(
    self, ctm_file: Pathlike, type: str = "word", match_channel: bool = False
) -> "SupervisionSet":
    """
    Add alignments from CTM file to the supervision set.

    Every CTM entry is expected to contain at least five whitespace-separated
    fields: ``<recording-id> <channel> <start> <duration> <symbol>``. Any further
    fields (e.g. an optional confidence score) are ignored, as are blank lines
    and ``;;`` comment lines.

    An entry is attached to a segment when the segment belongs to the same
    recording, ``overspans(segment, entry_time_span)`` is true and, if
    ``match_channel`` is set, the segment channel equals the CTM channel.
    Segments of recordings that do not appear in the CTM file are carried over
    as they are.

    :param ctm_file: Path to CTM file.
    :param type: Alignment type (optional, default = `word`).
    :param match_channel: if True, also match channel between CTM and SupervisionSegment
    :return: A new SupervisionSet with AlignmentItem objects added to the segments.
    :raises ValueError: if a CTM entry has fewer than five fields, or its channel,
        start or duration field cannot be parsed as a number.
    """
    # NOTE: ``type`` shadows the builtin, but it is part of the public signature
    # (callers may pass it by keyword), so it must keep its name.
    ctm_words = []
    with open(ctm_file) as f:
        for line in f:
            fields = line.split()
            # Skip blank lines and CTM comment lines instead of crashing on them.
            if not fields or fields[0].startswith(";;"):
                continue
            # ``*_`` swallows optional trailing columns such as the confidence score.
            reco_id, channel, start, duration, symbol, *_ = fields
            ctm_words.append(
                (reco_id, int(channel), float(start), float(duration), symbol)
            )

    # Group the entries per recording, ordered by start time within a recording
    # (the sort is stable, so entries with equal start keep their file order).
    reco_to_ctm = defaultdict(list)
    for word in sorted(ctm_words, key=lambda w: (w[0], w[2])):
        reco_to_ctm[word[0]].append(word)

    segments = []
    num_total = len(ctm_words)
    num_overspanned = 0
    # dict.fromkeys de-duplicates the recording ids while keeping first-seen
    # order, so the order of the resulting segments is deterministic.
    for reco_id in dict.fromkeys(s.recording_id for s in self):
        if reco_id not in reco_to_ctm:
            # No CTM entries for this recording: keep its segments untouched.
            # ``extend`` (not ``append``) so that ``segments`` stays a flat list.
            segments.extend(self.find(recording_id=reco_id))
            continue
        reco_words = reco_to_ctm[reco_id]
        for seg in self.find(recording_id=reco_id):
            alignment = [
                AlignmentItem(symbol=word[4], start=word[2], duration=word[3])
                for word in reco_words
                if overspans(seg, TimeSpan(word[2], word[2] + word[3]))
                and (seg.channel == word[1] or not match_channel)
            ]
            num_overspanned += len(alignment)
            # NOTE(review): the copy receives a dict holding only this alignment
            # type; presumably any alignment already on the segment is replaced
            # rather than merged -- confirm that this is intended.
            segments.append(fastcopy(seg, alignment={type: alignment}))
    logging.info(
        f"{num_overspanned} alignments added out of {num_total} total. If there are several"
        " missing, there could be a mismatch problem."
    )
    return SupervisionSet.from_segments(segments)
def test_overspans(lhs, rhs, expected):
    """Check that ``overspans(lhs, rhs)`` yields the expected result for the given pair."""
    actual = overspans(lhs, rhs)
    assert actual == expected