Exemplo n.º 1
0
    def test_in_out(self):
        """Check ResNetSpeakerEncoder output shapes and L2 normalization.

        Runs ``forward`` with and without ``l2_norm`` on a random batch,
        verifies that re-normalizing the ``l2_norm=True`` output leaves it
        unchanged, and then checks the shape returned by
        ``compute_embedding`` for a batch of size one.
        """
        dummy_input = T.rand(4, 20, 80)  # B x T x D
        model = ResNetSpeakerEncoder(input_dim=80, proj_dim=256)
        # computing d vectors
        output = model.forward(dummy_input)
        assert output.shape[0] == 4
        assert output.shape[1] == 256
        output = model.forward(dummy_input, l2_norm=True)
        assert output.shape[0] == 4
        assert output.shape[1] == 256

        # check normalization
        output_norm = T.nn.functional.normalize(output, dim=1, p=2)
        # Use the largest absolute element-wise deviation: a signed sum lets
        # positive and negative errors cancel and could hide a wrong output.
        assert_diff = (output_norm - output).abs().max().item()
        assert output.type() == "torch.FloatTensor"
        assert assert_diff < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
        # compute d for a given batch
        dummy_input = T.rand(1, 240, 80)  # B x T x D
        output = model.compute_embedding(dummy_input,
                                         num_frames=160,
                                         num_eval=10)
        assert output.shape[0] == 1
        assert output.shape[1] == 256
        assert len(output.shape) == 2
Exemplo n.º 2
0
def setup_model(c):
    """Build the speaker encoder selected by ``c.model_params["model_name"]``.

    Args:
        c: Config object exposing a ``model_params`` mapping with the keys
            ``model_name``, ``input_dim`` and ``proj_dim`` (plus ``lstm_dim``
            and ``num_lstm_layers`` when the LSTM encoder is selected).

    Returns:
        An ``LSTMSpeakerEncoder`` or ``ResNetSpeakerEncoder`` instance.

    Raises:
        ValueError: If ``model_name`` is neither ``"lstm"`` nor ``"resnet"``
            (compared case-insensitively).
    """
    params = c.model_params
    model_name = params["model_name"].lower()
    if model_name == "lstm":
        return LSTMSpeakerEncoder(
            params["input_dim"],
            params["proj_dim"],
            params["lstm_dim"],
            params["num_lstm_layers"],
        )
    if model_name == "resnet":
        return ResNetSpeakerEncoder(input_dim=params["input_dim"], proj_dim=params["proj_dim"])
    # Fail loudly instead of hitting an UnboundLocalError on `return model`.
    raise ValueError(
        f" [!] Unknown speaker encoder model: {params['model_name']!r} "
        "(expected 'lstm' or 'resnet')"
    )
Exemplo n.º 3
0
def setup_speaker_encoder_model(config: "Coqpit"):
    """Build the speaker encoder selected by ``config.model_params["model_name"]``.

    Args:
        config: Config object exposing ``audio`` and a ``model_params``
            mapping. ``model_params`` must contain ``model_name``,
            ``input_dim`` and ``proj_dim`` (plus ``lstm_dim`` and
            ``num_lstm_layers`` for the LSTM encoder). The optional keys
            ``use_torch_spec`` and ``log_input`` default to ``False``.

    Returns:
        An ``LSTMSpeakerEncoder`` or ``ResNetSpeakerEncoder`` instance.

    Raises:
        ValueError: If ``model_name`` is neither ``"lstm"`` nor ``"resnet"``
            (compared case-insensitively).
    """
    params = config.model_params
    model_name = params["model_name"].lower()
    if model_name == "lstm":
        return LSTMSpeakerEncoder(
            params["input_dim"],
            params["proj_dim"],
            params["lstm_dim"],
            params["num_lstm_layers"],
            use_torch_spec=params.get("use_torch_spec", False),
            audio_config=config.audio,
        )
    if model_name == "resnet":
        return ResNetSpeakerEncoder(
            input_dim=params["input_dim"],
            proj_dim=params["proj_dim"],
            log_input=params.get("log_input", False),
            use_torch_spec=params.get("use_torch_spec", False),
            audio_config=config.audio,
        )
    # Fail loudly instead of hitting an UnboundLocalError on `return model`.
    raise ValueError(
        f" [!] Unknown speaker encoder model: {params['model_name']!r} "
        "(expected 'lstm' or 'resnet')"
    )