def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data_train
    global meta_data_eval

    # set up the audio processor and the encoder model
    ap = AudioProcessor(**c.audio)
    model = setup_model(c)

    optimizer = RAdam(model.parameters(), lr=c.lr)

    # pylint: disable=redefined-outer-name
    meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=False)

    data_loader, num_speakers = setup_loader(ap, is_val=False, verbose=True)

    # pick the training loss
    if c.loss == "ge2e":
        criterion = GE2ELoss(loss_method="softmax")
    elif c.loss == "angleproto":
        criterion = AngleProtoLoss()
    elif c.loss == "softmaxproto":
        criterion = SoftmaxAngleProtoLoss(c.model["proj_dim"], num_speakers)
    else:
        raise ValueError("%s is not a supported loss" % c.loss)

    # optionally restore model (and criterion) weights from a checkpoint
    if args.restore_path:
        checkpoint = torch.load(args.restore_path)
        try:
            model.load_state_dict(checkpoint["model"])

            if "criterion" in checkpoint:
                criterion.load_state_dict(checkpoint["criterion"])
        except (KeyError, RuntimeError):
            # checkpoint does not match the model; copy only the matching layers
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model.load_state_dict(model_dict)
            del model_dict
        for group in optimizer.param_groups:
            group["lr"] = c.lr

        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    if c.lr_decay:
        scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    if use_cuda:
        model = model.cuda()
        criterion.cuda()

    global_step = args.restore_step
    _, global_step = train(model, optimizer, scheduler, criterion, data_loader, global_step)
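# A minimal entry-point sketch for the trainer above, not the repository's
# actual CLI: it assumes the surrounding script already builds the global
# config `c` and the `use_cuda` flag, and that `--restore_path` is the only
# option needed here.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_path", type=str, default="", help="Checkpoint path to resume training from.")
    args = parser.parse_args()
    main(args)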
def init_speaker_encoder(self, model_path: str, config_path: str) -> None:
    self.speaker_encoder_config = load_config(config_path)
    self.speaker_encoder = setup_model(self.speaker_encoder_config)
    self.speaker_encoder.load_checkpoint(config_path, model_path, True)
    self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
    # normalize the input audio level and trim silences
    self.speaker_encoder_ap.do_sound_norm = True
    self.speaker_encoder_ap.do_trim_silence = True
def init_speaker_encoder(self, model_path: str, config_path: str) -> None:
    """Initialize a speaker encoder model.

    Args:
        model_path (str): Model file path.
        config_path (str): Model config file path.
    """
    self.speaker_encoder_config = load_config(config_path)
    self.speaker_encoder = setup_model(self.speaker_encoder_config)
    self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda)
    self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
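# Minimal usage sketch for the initializer above, assuming it is a method of
# a SpeakerManager-like class; the checkpoint and config paths below are
# hypothetical placeholders, not files shipped with the project.
manager = SpeakerManager()
manager.init_speaker_encoder(
    model_path="speaker_encoder_model.pth.tar",  # hypothetical checkpoint path
    config_path="speaker_encoder_config.json",   # hypothetical config path
)
# The encoder and its matching audio frontend are then available as
# `manager.speaker_encoder` and `manager.speaker_encoder_ap`.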
def test_speaker_embedding():
    # load config
    config = load_config(encoder_config_path)
    config.audio.resample = True

    # create a dummy speaker encoder
    model = setup_model(config)
    save_checkpoint(model, None, None, get_tests_input_path(), 0)

    # load audio processor and speaker encoder
    ap = AudioProcessor(**config.audio)
    manager = SpeakerManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)

    # load a sample audio and compute embedding
    waveform = ap.load_wav(sample_wav_path)
    mel = ap.melspectrogram(waveform)
    d_vector = manager.compute_d_vector(mel.T)
    assert d_vector.shape[1] == 256

    # compute d_vector directly from an input file
    d_vector = manager.compute_d_vector_from_clip(sample_wav_path)
    d_vector2 = manager.compute_d_vector_from_clip(sample_wav_path)
    d_vector = torch.FloatTensor(d_vector)
    d_vector2 = torch.FloatTensor(d_vector2)
    assert d_vector.shape[0] == 256
    assert (d_vector - d_vector2).sum() == 0.0

    # compute d_vector from a list of wav files.
    d_vector3 = manager.compute_d_vector_from_clip([sample_wav_path, sample_wav_path2])
    d_vector3 = torch.FloatTensor(d_vector3)
    assert d_vector3.shape[0] == 256
    assert (d_vector - d_vector3).sum() != 0.0

    # remove dummy model
    os.remove(encoder_model_path)
        for line in f:
            components = line.split(sep)
            if len(components) != 2:
                print("Invalid line")
                continue
            wav_file = os.path.join(wav_path, components[0] + ".wav")
            # print(f'wav_file: {wav_file}')
            if os.path.exists(wav_file):
                wav_files.append(wav_file)
        print(f"Count of wavs imported: {len(wav_files)}")
else:
    # Parse all wav files in data_path
    wav_files = glob.glob(data_path + "/**/*.wav", recursive=True)

# define Encoder model
model = setup_model(c)
model.load_state_dict(torch.load(args.model_path)["model"])
model.eval()
if args.use_cuda:
    model.cuda()

# compute speaker embeddings
speaker_mapping = {}
for idx, wav_file in enumerate(tqdm(wav_files)):
    if isinstance(wav_file, list):
        speaker_name = wav_file[2]
        wav_file = wav_file[1]
    else:
        speaker_name = None

    mel_spec = ap.melspectrogram(ap.load_wav(wav_file, sr=ap.sample_rate)).T
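    # Hedged continuation sketch for the loop above: turning each mel
    # spectrogram into a stored embedding. `model.compute_embedding` and the
    # `speaker_mapping` entry layout are assumptions inferred from the
    # surrounding code, not a confirmed API.
    mel_spec = torch.FloatTensor(mel_spec[None, :, :])
    if args.use_cuda:
        mel_spec = mel_spec.cuda()
    embedd = model.compute_embedding(mel_spec).detach().cpu().numpy()
    speaker_mapping[os.path.basename(wav_file)] = {
        "name": speaker_name,
        "embedding": embedd.flatten().tolist(),
    }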