def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    fs: int,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    model_tag: Optional[str],
    inference_config: Optional[str],
    allow_variable_data_keys: bool,
    segment_size: Optional[float],
    hop_size: Optional[float],
    normalize_segment_scale: bool,
    show_progressbar: bool,
    ref_channel: Optional[int],
    normalize_output_wav: bool,
    enh_s2t_task: bool,
):
    """Run speech enhancement/separation and write one output set per speaker.

    A ``SeparateSpeech`` object is built via ``SeparateSpeech.from_pretrained``,
    the input data is streamed through it one utterance at a time, and every
    separated stream is stored with a ``SoundScpWriter`` under
    ``<output_dir>/wavs/<k>`` with an index file ``<output_dir>/spk<k>.scp``
    (``k`` starts at 1).

    Args:
        output_dir: Destination directory; expanded and resolved before use.
        batch_size: Must not exceed 1 (batch decoding is not implemented).
        dtype: Forwarded to both ``SeparateSpeech`` and the data iterator.
        fs: Value stored together with every waveform by the writers
            (presumably the sampling rate in Hz -- confirm with the caller).
        ngpu: ``>= 1`` selects the ``"cuda"`` device, otherwise ``"cpu"``;
            values above 1 are rejected.
        seed: Passed to ``set_all_random_seed``.
        num_workers: Forwarded to the streaming iterator.
        log_level: Level given to ``logging.basicConfig``.
        data_path_and_name_and_type: Forwarded to the streaming iterator.
        key_file: Forwarded to the streaming iterator.
        train_config: Forwarded to ``SeparateSpeech.from_pretrained``.
        model_file: Forwarded to ``SeparateSpeech.from_pretrained``.
        model_tag: Forwarded to ``SeparateSpeech.from_pretrained``.
        inference_config: Forwarded to ``SeparateSpeech.from_pretrained``.
        allow_variable_data_keys: Forwarded to the streaming iterator.
        segment_size: Forwarded to ``SeparateSpeech.from_pretrained``.
        hop_size: Forwarded to ``SeparateSpeech.from_pretrained``.
        normalize_segment_scale: Forwarded to ``SeparateSpeech.from_pretrained``.
        show_progressbar: Forwarded to ``SeparateSpeech.from_pretrained``.
        ref_channel: Forwarded to ``SeparateSpeech.from_pretrained``.
        normalize_output_wav: Forwarded to ``SeparateSpeech.from_pretrained``.
        enh_s2t_task: Forwarded to ``SeparateSpeech.from_pretrained``.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build separate_speech
    separate_speech_kwargs = dict(
        train_config=train_config,
        model_file=model_file,
        inference_config=inference_config,
        segment_size=segment_size,
        hop_size=hop_size,
        normalize_segment_scale=normalize_segment_scale,
        show_progressbar=show_progressbar,
        ref_channel=ref_channel,
        normalize_output_wav=normalize_output_wav,
        device=device,
        dtype=dtype,
        enh_s2t_task=enh_s2t_task,
    )
    separate_speech = SeparateSpeech.from_pretrained(
        model_tag=model_tag,
        **separate_speech_kwargs,
    )

    # 3. Build data-iterator
    loader = EnhancementTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=EnhancementTask.build_preprocess_fn(
            separate_speech.enh_train_args, False
        ),
        collate_fn=EnhancementTask.build_collate_fn(
            separate_speech.enh_train_args, False
        ),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 4. Start for-loop
    output_dir = Path(output_dir).expanduser().resolve()
    writers = []
    # The writers are created inside the ``try`` so that every writer that was
    # successfully opened gets closed even when a later step raises.
    try:
        for i in range(separate_speech.num_spk):
            writers.append(
                SoundScpWriter(
                    f"{output_dir}/wavs/{i + 1}", f"{output_dir}/spk{i + 1}.scp"
                )
            )

        for i, (keys, batch) in enumerate(loader):
            logging.info(f"[{i}] Enhancing {keys}")
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # Drop the "*_lengths" entries before calling separate_speech.
            batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}

            waves = separate_speech(**batch)
            for spk, w in enumerate(waves):
                for b in range(batch_size):
                    writers[spk][keys[b]] = fs, w[b]
    finally:
        for writer in writers:
            writer.close()
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    fs: int,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    enh_train_config: str,
    enh_model_file: str,
    allow_variable_data_keys: bool,
    normalize_output_wav: bool,
):
    """Run an enhancement model over a dataset and write per-speaker output.

    The model is built with ``EnhancementTask.build_model_from_file`` and put
    in eval mode, the input is streamed one utterance at a time through
    ``enh_model.enh_model.forward_rawwav``, and each returned stream is stored
    with a ``SoundScpWriter`` under ``<output_dir>/wavs/<k>`` with an index
    file ``<output_dir>/spk<k>.scp`` (``k`` starts at 1).

    Args:
        output_dir: Destination directory, used as-is in the output paths.
        batch_size: Must not exceed 1 (batch decoding is not implemented).
        dtype: Forwarded to the streaming iterator.
        fs: Value stored together with every waveform by the writers
            (presumably the sampling rate in Hz -- confirm with the caller).
        ngpu: ``>= 1`` selects the ``"cuda"`` device, otherwise ``"cpu"``;
            values above 1 are rejected.
        seed: Passed to ``set_all_random_seed``.
        num_workers: Forwarded to the streaming iterator.
        log_level: Level given to ``logging.basicConfig``.
        data_path_and_name_and_type: Forwarded to the streaming iterator.
        key_file: Forwarded to the streaming iterator.
        enh_train_config: Passed to ``EnhancementTask.build_model_from_file``.
        enh_model_file: Passed to ``EnhancementTask.build_model_from_file``.
        allow_variable_data_keys: Forwarded to the streaming iterator.
        normalize_output_wav: When true, each output is divided by its maximum
            absolute value along dim 1 and multiplied by 0.9 before writing.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build Enh model
    enh_model, enh_train_args = EnhancementTask.build_model_from_file(
        enh_train_config, enh_model_file, device
    )
    enh_model.eval()

    num_spk = enh_model.num_spk

    # 3. Build data-iterator
    loader = EnhancementTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=EnhancementTask.build_preprocess_fn(enh_train_args, False),
        collate_fn=EnhancementTask.build_collate_fn(enh_train_args),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    writers = []
    # The writers are created inside the ``try`` so that every writer that was
    # successfully opened gets closed even when a later step raises.
    try:
        for i in range(num_spk):
            writers.append(
                SoundScpWriter(
                    f"{output_dir}/wavs/{i + 1}", f"{output_dir}/spk{i + 1}.scp"
                )
            )

        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            with torch.no_grad():
                # a. To device
                batch = to_device(batch, device)
                # b. Forward Enhancement Frontend
                waves, _, _ = enh_model.enh_model.forward_rawwav(
                    batch["speech_mix"], batch["speech_mix_lengths"]
                )
                assert len(waves[0]) == batch_size, len(waves[0])

            # FIXME(Chenda): will be incorrect when
            #  batch size is not 1 or multi-channel case
            if normalize_output_wav:
                # NOTE(review): an all-zero output would make the peak 0 and
                # this a division by zero -- confirm whether a guard is needed.
                waves = [
                    (w / abs(w).max(dim=1, keepdim=True)[0] * 0.9).T.cpu().numpy()
                    for w in waves
                ]  # list[(sample,batch)]
            else:
                waves = [w.T.cpu().numpy() for w in waves]
            for spk, w in enumerate(waves):
                writers[spk][keys[0]] = fs, w
    finally:
        for writer in writers:
            writer.close()