예제 #1
0
def main(args):  # pylint: disable=redefined-outer-name
    """Train the speaker encoder configured by the global config ``c``.

    Builds the audio processor, model, data loader and the loss selected by
    ``c.loss``. It optionally restores model (and loss) weights from
    ``args.restore_path``, then builds the optimizer/scheduler and runs ``train``.

    Args:
        args: Parsed command-line arguments. ``args.restore_path`` is read;
            ``args.restore_step`` is set to the restored step (0 when training
            from scratch).

    Raises:
        ValueError: If ``c.loss`` is not one of "ge2e", "angleproto" or
            "softmaxproto".
    """
    # pylint: disable=global-variable-undefined
    global meta_data_train
    global meta_data_eval

    ap = AudioProcessor(**c.audio)
    model = setup_model(c)

    # pylint: disable=redefined-outer-name
    meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=False)

    data_loader, num_speakers = setup_loader(ap, is_val=False, verbose=True)

    if c.loss == "ge2e":
        criterion = GE2ELoss(loss_method="softmax")
    elif c.loss == "angleproto":
        criterion = AngleProtoLoss()
    elif c.loss == "softmaxproto":
        criterion = SoftmaxAngleProtoLoss(c.model["proj_dim"], num_speakers)
    else:
        raise ValueError("The %s loss is not supported" % c.loss)

    if args.restore_path:
        # Load onto the CPU so a checkpoint saved on a GPU can be restored on
        # any machine; the modules are moved to the GPU further below.
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            model.load_state_dict(checkpoint["model"])

            if "criterion" in checkpoint:
                criterion.load_state_dict(checkpoint["criterion"])

        except (KeyError, RuntimeError):
            # Checkpoint does not match the current model exactly: copy over
            # whatever set_init_dict keeps and load that instead.
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model.load_state_dict(model_dict)
            del model_dict

        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    if use_cuda:
        model = model.cuda()
        criterion.cuda()

    # The loss carries its own state (it is checkpointed above and, for
    # "softmaxproto", sized by num_speakers), so its parameters are optimized
    # together with the model's. The optimizer is built after the .cuda()
    # calls, as recommended for modules that are optimized on the GPU.
    optimizer = RAdam(list(model.parameters()) + list(criterion.parameters()), lr=c.lr)

    if c.lr_decay:
        # Resume the schedule at the restored step.
        scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    global_step = args.restore_step
    _, global_step = train(model, optimizer, scheduler, criterion, data_loader, global_step)
예제 #2
0
 def init_speaker_encoder(self, model_path: str, config_path: str) -> None:
     """Load the speaker encoder, its config and a matching audio processor.

     Args:
         model_path (str): Path to the encoder checkpoint.
         config_path (str): Path to the encoder config file.
     """
     encoder_config = load_config(config_path)
     self.speaker_encoder_config = encoder_config
     self.speaker_encoder = setup_model(encoder_config)
     self.speaker_encoder.load_checkpoint(config_path, model_path, True)

     encoder_ap = AudioProcessor(**encoder_config.audio)
     # Level-normalise the input audio and strip silences before encoding.
     encoder_ap.do_sound_norm = True
     encoder_ap.do_trim_silence = True
     self.speaker_encoder_ap = encoder_ap
예제 #3
0
    def init_speaker_encoder(self, model_path: str, config_path: str) -> None:
        """Load the speaker encoder model, its config and its audio processor.

        Args:
            model_path (str): Path of the encoder checkpoint file.
            config_path (str): Path of the encoder config file.
        """
        encoder_config = load_config(config_path)
        self.speaker_encoder_config = encoder_config

        encoder = setup_model(encoder_config)
        self.speaker_encoder = encoder
        encoder.load_checkpoint(
            config_path, model_path, eval=True, use_cuda=self.use_cuda
        )

        self.speaker_encoder_ap = AudioProcessor(**encoder_config.audio)
예제 #4
0
    def test_speaker_embedding():
        """Check that SpeakerManager produces consistent 256-dim d-vectors.

        A dummy (untrained) encoder is saved as a checkpoint, loaded through
        SpeakerManager, and used to embed a mel spectrogram, the same clip
        twice, and a list of two clips. The dummy checkpoint is removed even
        when an assertion fails.
        """
        # load config
        config = load_config(encoder_config_path)
        config.audio.resample = True

        # create a dummy speaker encoder
        model = setup_model(config)
        save_checkpoint(model, None, None, get_tests_input_path(), 0)
        try:
            # load audio processor and speaker encoder
            ap = AudioProcessor(**config.audio)
            manager = SpeakerManager(encoder_model_path=encoder_model_path,
                                     encoder_config_path=encoder_config_path)

            # load a sample audio and compute embedding
            waveform = ap.load_wav(sample_wav_path)
            mel = ap.melspectrogram(waveform)
            d_vector = manager.compute_d_vector(mel.T)
            assert d_vector.shape[1] == 256

            # compute d_vector directly from an input file
            d_vector = manager.compute_d_vector_from_clip(sample_wav_path)
            d_vector2 = manager.compute_d_vector_from_clip(sample_wav_path)
            d_vector = torch.FloatTensor(d_vector)
            d_vector2 = torch.FloatTensor(d_vector2)
            assert d_vector.shape[0] == 256
            # Exact element-wise comparison: a signed sum of differences could
            # cancel to zero even when the vectors differ.
            assert torch.equal(d_vector, d_vector2)

            # compute d_vector from a list of wav files.
            d_vector3 = manager.compute_d_vector_from_clip(
                [sample_wav_path, sample_wav_path2])
            d_vector3 = torch.FloatTensor(d_vector3)
            assert d_vector3.shape[0] == 256
            assert not torch.equal(d_vector, d_vector3)
        finally:
            # remove dummy model (presumably the file save_checkpoint wrote)
            os.remove(encoder_model_path)
예제 #5
0
            for line in f:
                components = line.split(sep)
                if len(components) != 2:
                    print("Invalid line")
                    continue
                wav_file = os.path.join(wav_path, components[0] + ".wav")
                # print(f'wav_file: {wav_file}')
                if os.path.exists(wav_file):
                    wav_files.append(wav_file)
        print(f"Count of wavs imported: {len(wav_files)}")
    else:
        # Parse all wav files in data_path
        wav_files = glob.glob(data_path + "/**/*.wav", recursive=True)

# define Encoder model
model = setup_model(c)
# Load the checkpoint onto the CPU first so GPU-saved weights can be read on
# machines without CUDA; the model is moved to the GPU below when requested.
model.load_state_dict(torch.load(args.model_path, map_location="cpu")["model"])
model.eval()
if args.use_cuda:
    model.cuda()

# compute speaker embeddings
speaker_mapping = {}
for idx, wav_file in enumerate(tqdm(wav_files)):
    if isinstance(wav_file, list):
        speaker_name = wav_file[2]
        wav_file = wav_file[1]
    else:
        speaker_name = None

    mel_spec = ap.melspectrogram(ap.load_wav(wav_file, sr=ap.sample_rate)).T