def __init__(
    self,
    utt_ids: List[str],
    rxfiles: List[str],
    utt2num_frames: Optional[List[int]] = None,
    ordered_prefetch=False,
    cache_size=327680,
):
    """Dataset of Kaldi integer-vector alignments, read via ``kaldi_io``.

    Args:
        utt_ids: utterance ids, parallel to ``rxfiles``.
        rxfiles: Kaldi rxfilename specifiers, one per utterance.
        utt2num_frames: optional per-utterance lengths; when absent (or
            empty) every rxfile is read once up front to obtain its length.
        ordered_prefetch: set to True ONLY if examples are queried in the
            same order as ``self.ordered_indices``; this speeds up the
            search for the queried index.
        cache_size: cache capacity, in number of examples.
    """
    super().__init__()
    assert len(utt_ids) == len(rxfiles)
    self.dtype = np.int16
    self.utt_ids = utt_ids
    self.rxfiles = rxfiles
    self.size = len(utt_ids)  # number of utterances
    self.sizes = []  # length of each utterance
    if utt2num_frames is not None and len(utt2num_frames) > 0:
        assert len(utt2num_frames) == self.size
        self.sizes = utt2num_frames

    if len(self.sizes) == 0:
        # No lengths supplied: read each alignment once just to size it.
        for rxfile in self.rxfiles:
            try:
                ali = kaldi_io.read_vec_int(rxfile)
            except Exception as e:
                # chain the original error so the root cause is preserved
                raise Exception(
                    "failed to read int vector {}.".format(rxfile)
                ) from e
            assert ali is not None and isinstance(ali, np.ndarray)
            self.sizes.append(ali.shape[0])

    assert len(self.sizes) == self.size
    self.sizes = np.array(self.sizes, dtype=np.int32)

    self.cache = None
    self.cache_index = {}
    self.cache_size = cache_size  # in terms of number of examples
    self.start_pos_for_next_cache = 0
    self.ordered_indices = list(range(self.size))
    # set to True ONLY if examples are queried in the same order as
    # self.ordered_indices, and doing this will speed up search of the
    # queried index
    self.ordered_prefetch = ordered_prefetch
def __getitem__(self, i):
    """Return the alignment for index ``i`` as a 1-D LongTensor.

    Uses a sliding cache of up to ``self.cache_size`` consecutive examples
    (in ``self.ordered_indices`` order): on a cache miss, the window
    starting at ``i``'s position is bulk-loaded into ``self.cache``.
    """
    self.check_index(i)
    if i not in self.cache_index:
        assert self.start_pos_for_next_cache < \
            len(self.ordered_indices), \
            "Position for next cache starting beyond the end of ordered_indices."
        # Locate i's position in the query order, searching only from where
        # the previous cache window ended (0 unless ordered_prefetch).
        try:
            pos_start = self.ordered_indices.index(
                i, self.start_pos_for_next_cache,
            )
        except ValueError:
            raise ValueError(
                "index {} not found in self.ordered_indices. Set "
                "self.ordered_prefetch to False, and/or call self.prefetch() "
                "with the full list of indices, and then try again.".
                format(i))
        pos_end = min(
            pos_start + self.cache_size, len(self.ordered_indices),
        )
        # With ordered_prefetch, the next miss resumes the search after this
        # window; otherwise always search from the beginning.
        self.start_pos_for_next_cache = pos_end \
            if self.ordered_prefetch else 0
        # Size one flat buffer for the whole window, then fill it.
        total_size = 0
        for idx in self.ordered_indices[pos_start:pos_end]:
            total_size += self.sizes[idx]
        self.cache = np.empty(total_size, dtype=self.dtype)
        ptx = 0
        self.cache_index.clear()
        for idx in self.ordered_indices[pos_start:pos_end]:
            self.cache_index[idx] = ptx
            length = self.sizes[idx]
            dst = self.cache[ptx:ptx + length]
            # NOTE(review): copyto casts into self.dtype (int16); assumes
            # alignment ids fit in int16 — TODO confirm.
            np.copyto(dst, kaldi_io.read_vec_int(self.rxfiles[idx]))
            ptx += length
    # Cache hit (possibly just populated): slice out example i and copy so
    # the returned tensor does not alias the shared cache buffer.
    ptx = self.cache_index[i]
    a = self.cache[ptx:ptx + self.sizes[i]].copy()
    return torch.from_numpy(a).long()
def main(args):
    """Estimate the initial state prior from Kaldi alignment files.

    Accumulates state-id occurrence counts over all alignments listed in
    ``args.alignment_files`` (each line: ``<utt_id> <rxfile>``), normalizes
    the counts into a distribution, floors it at ``args.prior_floor``,
    renormalizes, and writes the result as a Kaldi float vector to
    ``args.output``.

    Raises:
        Exception: if an alignment rxfile cannot be read.
    """
    assert args.prior_floor > 0.0 and args.prior_floor < 1.0
    prior = np.zeros((args.prior_dim,), dtype=np.int32)
    for path in args.alignment_files:
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                _, rxfile = line.strip().split(None, 1)
                try:
                    ali = kaldi_io.read_vec_int(rxfile)
                except Exception as e:
                    # chain the original error so the root cause is preserved
                    raise Exception(
                        "failed to read int vector {}.".format(rxfile)
                    ) from e
                assert ali is not None and isinstance(ali, np.ndarray)
                # vectorized count accumulation; ufunc.at is unbuffered, so
                # repeated state ids within one alignment are all counted
                np.add.at(prior, ali, 1)
    prior = np.maximum(prior / float(np.sum(prior)), args.prior_floor)  # normalize and floor
    prior = prior / float(np.sum(prior))  # normalize again
    kaldi_io.write_vec_flt(args.output, prior)
    logger.info("Saved the initial state prior estimate in {}".format(
        args.output))
plda = PLDA() plda.Read(args.plda_file) dim = plda.Dim() spk2num_utts = {} if args.num_utts != '': if not os.path.exists(args.num_utts): raise FileExistsError(args.num_utts) else: with kaldi_io.open_or_fd(args.num_utts) as f: while True: key = kaldi_io.read_key(f) if key != "": value = kaldi_io.read_vec_int(f) spk2num_utts[key] = value[0] else: break sub_mean = False if args.subtract_global_mean: if args.mean_vec != "" and os.path.exists(args.mean_vec): with open(args.mean_vec, 'rb') as f: try: global_mean = _read_vec_flt_binary(f) sub_mean = True except UnknownVectorHeader as u: mean_vec = [] vec_str = f.readline() for v in vec_str.split():