def print_result(self, filename, loglikes, words):
    """Log the decoding result for one wav file.

    Args:
        filename: path of the decoded wav file; logged as a resolved path.
        loglikes: network output; converted with ``onehot2int`` to per-frame
            labels, which are logged only when ``self.verbose`` is set.
        words: tensor of word indices produced by the decoder.
    """
    logger.info(f"decoding wav file: {Path(filename).resolve()}")
    if self.verbose:
        labels = onehot2int(loglikes).squeeze()
        if labels.dim() == 0:
            # squeeze() collapses a single-frame result to a 0-dim tensor,
            # whose tolist() is a scalar and which cannot be iterated
            labels = labels.unsqueeze(0)
        logger.info(f"labels: {' '.join([str(x) for x in labels.tolist()])}")
        symbols = [self.decoder.labeler.idx2phone(x)
                   for x in remove_duplicates(labels, blank=0)]
        logger.info(f"symbols: {' '.join(symbols)}")
    words = words.squeeze()
    if words.dim() == 0:
        # a single decoded word squeezes to a 0-dim tensor; it is a real
        # result and must not be reported as a null output
        words = words.unsqueeze(0)
    # decide emptiness by element count: squeezing an empty result leaves a
    # 1-D tensor with zero elements, so dim() cannot detect it
    if words.numel():
        text = ' '.join([self.decoder.labeler.idx2word(i) for i in words])
    else:
        text = '<null output from decoder>'
    logger.info(f"decoded text: {text}")
def unit_validate(self, data):
    """Run the model on one validation batch; return hypotheses and references.

    Args:
        data: a 6-tuple ``(xs, ys, frame_lens, label_lens, filenames, _)``.
            ``ys`` is sliced as the targets of all utterances laid end to end,
            with ``label_lens`` giving each utterance's length.

    Returns:
        ``(hyps, refs)``: one hypothesis per utterance (frame-wise labels
        post-processed by ``remove_duplicates(..., blank=0)``) and the
        matching slice of ``ys`` for each utterance.
    """
    xs, ys, frame_lens, label_lens, _filenames, _ = data
    if self.use_cuda:
        xs = xs.cuda(non_blocking=True)
    ys_hat, frame_lens = self.model(xs, frame_lens)
    if self.fp16:
        ys_hat = ys_hat.float()
    # convert likes to ctc labels, trimming each utterance to its valid frames
    hyps = []
    for yh, n_frames in zip(ys_hat, frame_lens):
        labels = onehot2int(yh[:n_frames]).squeeze()
        if labels.dim() == 0:
            # a one-frame utterance squeezes to a 0-dim tensor, which cannot
            # be iterated; keep it as a length-1 sequence
            labels = labels.unsqueeze(0)
        hyps.append(remove_duplicates(labels, blank=0))
    # slice the targets: derive start offsets from the running end offsets
    # instead of concatenating a separately allocated zero tensor, so no
    # dtype/device mismatch with label_lens can occur
    ends = torch.cumsum(label_lens, dim=0)
    starts = ends - label_lens
    refs = [ys[s:e] for s, e in zip(starts, ends)]
    return hyps, refs
def make_ctc_labels(self):
    """Write a ``.ctc`` label file beside every ``.phn`` file under ``self.target_path``.

    Each ``.phn`` file is loaded as integers, passed through
    ``remove_duplicates`` and saved with format ``"%d"`` to the same path with
    the extension changed to ``.ctc``. Finally ``count_priors`` is called with
    the list of ``.phn`` file paths (as strings).
    """
    # find *.phn files
    logger.info(f"finding *.phn files under {self.target_path}")
    phn_paths = list(self.target_path.rglob("*.phn"))
    phn_files = [str(p) for p in phn_paths]
    # convert
    for phn_path in tqdm(phn_paths, ncols=params.NCOLS):
        phns = np.loadtxt(str(phn_path), dtype="int", ndmin=1)
        # make ctc labelings by removing duplications
        ctcs = np.array(list(remove_duplicates(phns)))
        # write ctc file
        # blank labels will be inserted in warp-ctc loss module,
        # so here the target labels must not contain interleaved blanks.
        # Only the extension is swapped: str.replace("phn", "ctc") would also
        # rewrite any "phn" occurring in directory names or the file stem.
        ctc_file = phn_path.with_suffix(".ctc")
        np.savetxt(str(ctc_file), ctcs, "%d")
    count_priors(phn_files)