Example 1
import torch

from espnet2.asr.frontend.default import DefaultFrontend


def test_frontend_backward():
    frontend = DefaultFrontend(fs=160,
                               n_fft=128,
                               win_length=32,
                               hop_length=32,
                               frontend_conf=None)
    x = torch.randn(2, 300, requires_grad=True)
    x_lengths = torch.LongTensor([300, 89])
    y, y_lengths = frontend(x, x_lengths)
    y.sum().backward()
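    # Not part of the original test, but a natural extra assertion: backward()
    # populated gradients on the raw waveform, i.e. the whole STFT + log-mel
    # pipeline is differentiable.
    assert x.grad is not None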
Example 2
import pytest
import torch

from espnet2.asr.frontend.default import DefaultFrontend


# The pytest parametrization below is assumed; the original decorators were
# lost when the snippet was extracted.
@pytest.mark.parametrize("train", [True, False])
@pytest.mark.parametrize("use_wpe", [True, False])
@pytest.mark.parametrize("use_beamformer", [True, False])
def test_frontend_backward_multi_channel(train, use_wpe, use_beamformer):
    frontend = DefaultFrontend(
        fs=300,
        n_fft=128,
        win_length=128,
        frontend_conf={"use_wpe": use_wpe, "use_beamformer": use_beamformer},
    )
    if train:
        frontend.train()
    else:
        frontend.eval()
    # (batch, samples, channels): two 2-channel waveforms of 1000 samples each
    x = torch.randn(2, 1000, 2, requires_grad=True)
    x_lengths = torch.LongTensor([1000, 980])
    y, y_lengths = frontend(x, x_lengths)
    y.sum().backward()
Example 3
import torch

from espnet2.asr.decoder.transformer_decoder import TransformerDecoder
from espnet2.asr.encoder.transformer_encoder import TransformerEncoder
from espnet2.asr.frontend.default import DefaultFrontend
from espnet2.enh.loss.criterions.time_domain import SISNRLoss
from espnet2.enh.loss.wrappers.fixed_order import FixedOrderSolver
from espnet2.enh.separator.rnn_separator import RNNSeparator

enh_rnn_separator = RNNSeparator(
    input_dim=17,
    layer=1,
    unit=10,
    num_spk=1,
)

si_snr_loss = SISNRLoss()

fix_order_solver = FixedOrderSolver(criterion=si_snr_loss)

default_frontend = DefaultFrontend(
    fs=300,
    n_fft=32,
    win_length=32,
    hop_length=24,
    n_mels=32,
)

token_list = ["<blank>", "<space>", "a", "e", "i", "o", "u", "<sos/eos>"]

asr_transformer_encoder = TransformerEncoder(
    32,
    output_size=16,
    linear_units=16,
    num_blocks=2,
)

# The original snippet breaks off mid-call here; the completion below is a
# minimal assumption, matching the encoder's output_size of 16 above.
asr_transformer_decoder = TransformerDecoder(
    len(token_list),
    16,
    linear_units=16,
    num_blocks=2,
)
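
# Not in the original snippet: a quick sanity check of how the pieces fit.
# With n_mels=32 the frontend emits (batch, frames, 32) features, matching
# the encoder's input size of 32 above.
wav = torch.randn(2, 300)
wav_lengths = torch.LongTensor([300, 300])
feats, feats_lengths = default_frontend(wav, wav_lengths)
assert feats.shape[-1] == 32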
Example 4
import numpy as np
import torch
from typeguard import check_argument_types

from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.frontend.default import DefaultFrontend
from espnet2.asr.frontend.s3prl import S3prlFrontend


class FusedFrontends(AbsFrontend):
    def __init__(self,
                 frontends=None,
                 align_method="linear_projection",
                 proj_dim=100,
                 fs=16000):

        assert check_argument_types()
        super().__init__()
        self.align_method = (
            align_method  # fusion method; only "linear_projection" for now
        )
        self.proj_dim = proj_dim  # dimension each frontend is projected to
        self.frontends = []  # list of frontends to combine

        for i, frontend in enumerate(frontends):
            frontend_type = frontend["frontend_type"]
            if frontend_type == "default":
                n_mels, fs, n_fft, win_length, hop_length = (
                    frontend.get("n_mels", 80),
                    fs,
                    frontend.get("n_fft", 512),
                    frontend.get("win_length"),
                    frontend.get("hop_length", 128),
                )
                window, center, normalized, onesided = (
                    frontend.get("window", "hann"),
                    frontend.get("center", True),
                    frontend.get("normalized", False),
                    frontend.get("onesided", True),
                )
                fmin, fmax, htk, apply_stft = (
                    frontend.get("fmin", None),
                    frontend.get("fmax", None),
                    frontend.get("htk", False),
                    frontend.get("apply_stft", True),
                )

                self.frontends.append(
                    DefaultFrontend(
                        n_mels=n_mels,
                        n_fft=n_fft,
                        fs=fs,
                        win_length=win_length,
                        hop_length=hop_length,
                        window=window,
                        center=center,
                        normalized=normalized,
                        onesided=onesided,
                        fmin=fmin,
                        fmax=fmax,
                        htk=htk,
                        apply_stft=apply_stft,
                    ))
            elif frontend_type == "s3prl":
                frontend_conf, download_dir, multilayer_feature = (
                    frontend.get("frontend_conf"),
                    frontend.get("download_dir"),
                    frontend.get("multilayer_feature"),
                )
                self.frontends.append(
                    S3prlFrontend(
                        fs=fs,
                        frontend_conf=frontend_conf,
                        download_dir=download_dir,
                        multilayer_feature=multilayer_feature,
                    ))

            else:
                raise NotImplementedError(
                    f"unsupported frontend_type: {frontend_type!r}; "
                    "only 'default' and 's3prl' are supported"
                )

        self.frontends = torch.nn.ModuleList(self.frontends)

        # The frontends emit frames at different hop lengths; align them on a
        # common time grid: the gcd of the hop lengths is the finest common
        # hop, and each frontend's frames expand by hop_length // gcd.
        self.gcd = np.gcd.reduce(
            [frontend.hop_length for frontend in self.frontends])
        self.factors = [
            frontend.hop_length // self.gcd for frontend in self.frontends
        ]
        dev = "cuda" if torch.cuda.is_available() else "cpu"
        if self.align_method == "linear_projection":
            self.projection_layers = [
                torch.nn.Linear(
                    in_features=frontend.output_size(),
                    out_features=self.factors[i] * self.proj_dim,
                ) for i, frontend in enumerate(self.frontends)
            ]
            self.projection_layers = torch.nn.ModuleList(
                self.projection_layers)
            self.projection_layers = self.projection_layers.to(
                torch.device(dev))
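
For reference, a hypothetical instantiation of the constructor above; the keys mirror the ones read in the loop, and the s3prl values are illustrative only:

fused = FusedFrontends(
    frontends=[
        {"frontend_type": "default", "n_mels": 80, "hop_length": 128},
        {
            "frontend_type": "s3prl",
            "frontend_conf": {"upstream": "wav2vec2"},
            "download_dir": "./hub",
            "multilayer_feature": True,
        },
    ],
    align_method="linear_projection",
    proj_dim=100,
    fs=16000,
)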
Example 5
from espnet2.asr.frontend.default import DefaultFrontend


def test_frontend_repr():
    frontend = DefaultFrontend(fs="16k")
    print(frontend)
Example 6
from espnet2.asr.frontend.default import DefaultFrontend


def test_frontend_output_size():
    frontend = DefaultFrontend(fs="16k", n_mels=40)
    assert frontend.output_size() == 40
Example 7
import pytest
import torch

from espnet2.asr.encoder.transformer_encoder import TransformerEncoder
from espnet2.asr.frontend.default import DefaultFrontend
from espnet2.diar.attractor.rnn_attractor import RnnAttractor
from espnet2.diar.decoder.linear_decoder import LinearDecoder
from espnet2.diar.espnet_model import ESPnetDiarizationModel
from espnet2.layers.label_aggregation import LabelAggregate

frontend = DefaultFrontend(
    n_fft=32,
    win_length=32,
    hop_length=16,
    n_mels=10,
)

encoder = TransformerEncoder(
    input_size=10,
    input_layer="linear",
    num_blocks=1,
    linear_units=32,
    output_size=16,
    attention_heads=2,
)

decoder = LinearDecoder(
    num_spk=2,
    encoder_output_size=encoder.output_size(),
)
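
The so-far-unused imports above (RnnAttractor, ESPnetDiarizationModel, LabelAggregate) suggest the excerpt continued by assembling the full diarization model. A hedged sketch of that assembly, assuming the espnet2 constructor arguments shown below:

# Assumed completion; the label aggregation window/hop are chosen here to
# match the frontend's framing.
label_aggregator = LabelAggregate(win_length=32, hop_length=16)
attractor = RnnAttractor(encoder_output_size=encoder.output_size())

model = ESPnetDiarizationModel(
    frontend=frontend,
    specaug=None,
    normalize=None,
    label_aggregator=label_aggregator,
    encoder=encoder,
    decoder=decoder,
    attractor=attractor,
)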