예제 #1
0
 def print_result(self, filename, loglikes, words):
     """Log the decoding outcome for a single wav file.

     The resolved path of ``filename`` and the decoded text are always logged.
     When ``self.verbose`` is truthy, the label sequence derived from
     ``loglikes`` and the corresponding phone symbols are logged as well.

     Args:
         filename: path of the decoded wav file.
         loglikes: model output handed to ``onehot2int`` (verbose mode only).
         words: tensor of word indices; each element is mapped through
             ``self.decoder.labeler.idx2word``.
     """
     wav_path = str(Path(filename).resolve())
     logger.info(f"decoding wav file: {wav_path}")

     if self.verbose:
         frame_labels = onehot2int(loglikes).squeeze()
         label_text = ' '.join(map(str, frame_labels.tolist()))
         logger.info(f"labels: {label_text}")
         # collapse the label sequence (blank index 0) before mapping to phones
         phones = []
         for idx in remove_duplicates(frame_labels, blank=0):
             phones.append(self.decoder.labeler.idx2phone(idx))
         symbol_text = ' '.join(phones)
         logger.info(f"symbols: {symbol_text}")

     word_ids = words.squeeze()
     # A 0-dim tensor cannot be iterated, so it is reported as null output.
     # NOTE(review): a tensor holding exactly one word also squeezes to 0-dim
     # and would be logged as null -- confirm this matches the decoder's contract.
     if word_ids.dim():
         text = ' '.join(self.decoder.labeler.idx2word(w) for w in word_ids)
     else:
         text = '<null output from decoder>'
     logger.info(f"decoded text: {text}")
예제 #2
0
 def unit_validate(self, data):
     """Run the model on one validation batch and build hypothesis/reference pairs.

     Args:
         data: 6-item batch unpacked as
             ``(xs, ys, frame_lens, label_lens, filenames, _)``; only the first
             four items are used here.

     Returns:
         A tuple ``(hyps, refs)`` of two lists. ``hyps`` holds, per utterance,
         the result of ``remove_duplicates(..., blank=0)`` applied to the labels
         of the valid frames; ``refs`` holds the slices of ``ys`` delimited by
         the cumulative sums of ``label_lens``.
     """
     inputs, targets, in_lens, target_lens, _, _ = data
     if self.use_cuda:
         inputs = inputs.cuda(non_blocking=True)

     # the model also returns the frame lengths to use on its output
     outputs, out_lens = self.model(inputs, in_lens)
     if self.fp16:
         # cast the model output to float before extracting labels
         outputs = outputs.float()

     # convert likes to ctc labels, keeping only the valid frames of each utterance
     frame_labels = []
     for out, n_frames in zip(outputs, out_lens):
         frame_labels.append(onehot2int(out[:n_frames]).squeeze())
     hyps = []
     for seq in frame_labels:
         hyps.append(remove_duplicates(seq, blank=0))

     # slice the targets: boundaries are 0 followed by the running sum of label_lens
     ends = torch.cumsum(target_lens, dim=0)
     bounds = torch.cat((torch.zeros((1, ), dtype=torch.long), ends))
     refs = [targets[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]

     return hyps, refs
예제 #3
0
 def make_ctc_labels(self):
     """Write a ``.ctc`` label file beside every ``*.phn`` file under ``self.target_path``.

     Each ``*.phn`` file is loaded as a 1-D integer array, passed through
     ``remove_duplicates`` and saved, one integer per line, to a file with the
     same path but a ``.ctc`` extension. Afterwards ``count_priors`` is called
     with the list of ``*.phn`` file paths (as strings).
     """
     # find *.phn files
     logger.info(f"finding *.phn files under {str(self.target_path)}")
     # keep the path objects yielded by rglob (assumed pathlib paths) so the
     # extension can be swapped safely below
     phn_paths = list(self.target_path.rglob("*.phn"))
     phn_files = [str(p) for p in phn_paths]
     # convert
     for phn_path in tqdm(phn_paths, ncols=params.NCOLS):
         # ndmin=1 keeps a single-label file as a 1-D array
         phns = np.loadtxt(str(phn_path), dtype="int", ndmin=1)
         # make ctc labelings by removing duplications
         ctcs = np.array(list(remove_duplicates(phns)))
         # write ctc file
         # blank labels will be inserted in warp-ctc loss module,
         # so here the target labels must not contain interleaved blanks.
         # Replace only the extension: a plain str.replace("phn", "ctc") would
         # also rewrite any "phn" found in directory names or the file stem.
         ctc_file = phn_path.with_suffix(".ctc")
         np.savetxt(str(ctc_file), ctcs, "%d")
     count_priors(phn_files)