def test_Encoder_forward_backward(
    input_layer,
    positionwise_layer_type,
    interctc_layer_idx,
    interctc_use_conditioning,
):
    encoder = TransformerEncoder(
        20,
        output_size=40,
        input_layer=input_layer,
        positionwise_layer_type=positionwise_layer_type,
        interctc_layer_idx=interctc_layer_idx,
        interctc_use_conditioning=interctc_use_conditioning,
    )
    if input_layer == "embed":
        x = torch.randint(0, 10, [2, 10])
    else:
        x = torch.randn(2, 10, 20, requires_grad=True)
    x_lens = torch.LongTensor([10, 8])
    if len(interctc_layer_idx) > 0:
        ctc = None
        if interctc_use_conditioning:
            vocab_size = 5
            output_size = encoder.output_size()
            ctc = CTC(odim=vocab_size, encoder_output_size=output_size)
            encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, output_size)
        y, _, _ = encoder(x, x_lens, ctc=ctc)
        y = y[0]
    else:
        y, _, _ = encoder(x, x_lens)
    y.sum().backward()
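These ESPnet-style tests rely on pytest parametrization and imports that the snippet above does not show. A minimal sketch of what the surrounding test module presumably contains; the exact parameter grids here are an assumption, not the original ones:

import pytest
import torch

from espnet2.asr.ctc import CTC
from espnet2.asr.encoder.transformer_encoder import TransformerEncoder


# Illustrative value grids; the real test file may parametrize differently.
@pytest.mark.parametrize("input_layer", ["linear", "embed"])
@pytest.mark.parametrize("positionwise_layer_type", ["linear", "conv1d", "conv1d-linear"])
@pytest.mark.parametrize(
    "interctc_layer_idx, interctc_use_conditioning",
    [([], False), ([1], False), ([1], True)],
)
def test_Encoder_forward_backward(
    input_layer,
    positionwise_layer_type,
    interctc_layer_idx,
    interctc_use_conditioning,
):
    ...  # body as in the example above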
def test_encoder_invalid_interctc_layer_idx():
    with pytest.raises(AssertionError):
        TransformerEncoder(
            20,
            num_blocks=2,
            interctc_layer_idx=[0, 1],
        )
    with pytest.raises(AssertionError):
        TransformerEncoder(
            20,
            num_blocks=2,
            interctc_layer_idx=[1, 2],
        )
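Judging from the two failing cases above, the intermediate-CTC layer indices must lie strictly between 0 and num_blocks. For comparison, a configuration like the following would be expected to construct without raising (a sketch, reusing the import from the previous example):

# num_blocks=3, so indices 1 and 2 are strictly between 0 and num_blocks.
encoder = TransformerEncoder(
    20,
    num_blocks=3,
    interctc_layer_idx=[1, 2],
)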
def test_Encoder_forward_backward(input_layer, positionwise_layer_type):
    encoder = TransformerEncoder(
        20,
        output_size=40,
        input_layer=input_layer,
        positionwise_layer_type=positionwise_layer_type,
    )
    if input_layer == "embed":
        x = torch.randint(0, 10, [2, 10])
    elif input_layer is None:
        x = torch.randn(2, 10, 40, requires_grad=True)
    else:
        x = torch.randn(2, 10, 20, requires_grad=True)
    x_lens = torch.LongTensor([10, 8])
    y, _, _ = encoder(x, x_lens)
    y.sum().backward()
Example #5
class TransformerTransducer(nn.Module):
    def __init__(self, config):
        super(TransformerTransducer, self).__init__()
        self.vocab_size = config.joint.vocab_size
        self.sos = self.vocab_size - 1
        self.eos = self.vocab_size - 1
        self.ignore_id = -1
        self.encoder_left_mask = config.mask.encoder_left_mask
        self.encoder_right_mask = config.mask.encoder_right_mask
        self.decoder_left_mask = config.mask.decoder_left_mask

        self.encoder = TransformerEncoder(**config.enc)
        self.decoder = TransformerEncoder(**config.dec)
        self.joint = JointNetwork(**config.joint)
        self.loss = TransLoss(trans_type="warp-transducer",
                              blank_id=0)  # todo: check blank id

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ):
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)
        # 1. Encoder
        encoder_out, encoder_out_lens, _ = self.encoder(
            speech,
            speech_lengths,
            left_mask=self.encoder_left_mask,
            right_mask=self.encoder_right_mask)  # return xs_pad, olens, None

        # 2. Decoder
        # todo: train right shift
        text_in, text_out = add_sos_eos(text, self.sos, self.eos,
                                        self.ignore_id)
        text_in_lens = text_lengths + 1
        decoder_out, decoder_out_lens, _ = self.decoder(
            text_in,
            text_in_lens,
            left_mask=self.decoder_left_mask,
            right_mask=0)  # return xs_pad, olens, None

        # 3. Joint
        # h_enc: Batch of expanded hidden state (B, T, 1, D_enc)
        # h_dec: Batch of expanded hidden state (B, 1, U, D_dec)
        encoder_out = encoder_out.unsqueeze(2)
        decoder_out = decoder_out.unsqueeze(1)
        joint_out = self.joint(h_enc=encoder_out, h_dec=decoder_out)

        # 4. Loss
        # pred_pad (torch.Tensor): Batch of predicted sequences
        loss = self.loss(
            pred_pad=joint_out,  # (batch, maxlen_in, maxlen_out+1, odim)
            target=text.int(),  # (batch, maxlen_out)
            pred_len=speech_lengths.int(),  # (batch)
            target_len=text_lengths.int())  # (batch)
        return loss

    @torch.no_grad()
    def decode(self, enc_state, lengths):
        # token_list = []
        token_list = [self.sos]
        device = torch.device("cuda" if enc_state.is_cuda else "cpu")
        token = torch.tensor([token_list], dtype=torch.long).to(device)
        decoder_out, decoder_out_lens, _ = self.decoder.forward_one_step(
            token, self.decoder_left_mask)
        decoder_out = decoder_out[:, -1, :]
        for t in range(lengths):
            logits = self.joint(enc_state[t].view(-1), decoder_out.view(-1))
            out = F.softmax(logits, dim=0).detach()
            pred = torch.argmax(out, dim=0)
            pred = int(pred.item())

            if pred != 0:  # blank_id
                token_list.append(pred)
                token = torch.tensor([token_list], dtype=torch.long)
                if enc_state.is_cuda:
                    token = token.cuda()
                decoder_out, decoder_out_lens, _ = self.decoder.forward_one_step(
                    token, self.decoder_left_mask)  # feed the full history, keep only the last output
                decoder_out = decoder_out[:, -1, :]
        # return token_list
        return token_list[1:]

    @torch.no_grad()
    def recognize(self, speech: torch.Tensor,
                  speech_lengths: torch.Tensor) -> list:
        batch_size = speech.size(0)
        encoder_out, encoder_out_lens, _ = self.encoder(
            speech,
            speech_lengths,
            left_mask=self.encoder_left_mask,
            right_mask=self.encoder_right_mask)
        results = []
        for batch in range(batch_size):
            decoded_seq = self.decode(encoder_out[batch],
                                      speech_lengths[batch])
            results.append(decoded_seq)
        return results
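In the forward method above, the two unsqueeze calls set up broadcasting so that the joint network scores every (frame, token) pair. A small shape check that is independent of the model itself; the addition stands in for whatever combination the project's JointNetwork actually applies:

import torch

B, T, U, D = 2, 5, 3, 8
encoder_out = torch.randn(B, T, D).unsqueeze(2)  # (B, T, 1, D)
decoder_out = torch.randn(B, U, D).unsqueeze(1)  # (B, 1, U, D)

# Broadcasting over dims 1 and 2 yields one vector per (frame, token) pair.
joint_in = encoder_out + decoder_out
print(joint_in.shape)  # torch.Size([2, 5, 3, 8])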
Example #6
    def __init__(
        self,
        n_fft: int = 256,
        win_length: Optional[int] = None,
        hop_length: int = 128,
        dnn_type: str = "transformer",
        #layer: int = 3,
        #unit: int = 512,
        dropout: float = 0.0,
        num_spk: int = 2,
        nonlinear: str = "sigmoid",
        utt_mvn: bool = False,
        mask_type: str = "IRM",
        loss_type: str = "mask_mse",
        d_model: int = 256,
        nhead: int = 4,
        linear_units: int = 2048,
        num_layers: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "linear",
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
    ):
        super(TFMaskingTransformer, self).__init__()

        self.num_spk = num_spk
        self.num_bin = n_fft // 2 + 1
        self.mask_type = mask_type
        self.loss_type = loss_type
        if loss_type not in ("mask_mse", "magnitude", "spectrum"):
            raise ValueError("Unsupported loss type: %s" % loss_type)

        self.stft = Stft(n_fft=n_fft, win_length=win_length, hop_length=hop_length)

        if utt_mvn:
            self.utt_mvn = UtteranceMVN(norm_means=True, norm_vars=True)

        else:
            self.utt_mvn = None

        #self.rnn = RNN(
        #    idim=self.num_bin,
        #    elayers=layer,
        #    cdim=unit,
        #    hdim=unit,
        #    dropout=dropout,
        #    typ=rnn_type,
        #)

        self.encoder = TransformerEncoder(
            input_size=self.num_bin,
            output_size=d_model,
            attention_heads=nhead,
            linear_units=linear_units,
            num_blocks=num_layers,
            positional_dropout_rate=positional_dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            input_layer=input_layer,
            normalize_before=normalize_before,
            concat_after=concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            padding_idx=padding_idx,
        )
        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(d_model, self.num_bin) for _ in range(self.num_spk)]
        )

        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        self.nonlinear = {
            "sigmoid": torch.nn.Sigmoid(),
            "relu": torch.nn.ReLU(),
            "tanh": torch.nn.Tanh(),
        }[nonlinear]
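The __init__ above only constructs the modules; the forward pass is not shown in this snippet. A minimal sketch of how these modules would typically be combined to estimate one mask per speaker, based solely on what is built here (the method name is hypothetical):

    def _predict_masks(self, feature, feature_lengths):
        # Optional utterance-level mean/variance normalization.
        if self.utt_mvn is not None:
            feature, feature_lengths = self.utt_mvn(feature, feature_lengths)
        # Transformer encoder over (batch, frames, num_bin) magnitude features.
        encoded, encoded_lengths, _ = self.encoder(feature, feature_lengths)
        # One linear head plus the chosen nonlinearity per speaker.
        masks = [self.nonlinear(linear(encoded)) for linear in self.linear]
        return masks, encoded_lengths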
def test_Encoder_invalid_type():
    with pytest.raises(ValueError):
        TransformerEncoder(20, input_layer="fff")
def test_Encoder_output_size():
    encoder = TransformerEncoder(20, output_size=256)
    assert encoder.output_size() == 256
Example #9
fix_order_solver = FixedOrderSolver(criterion=si_snr_loss)

default_frontend = DefaultFrontend(
    fs=300,
    n_fft=32,
    win_length=32,
    hop_length=24,
    n_mels=32,
)

token_list = ["<blank>", "<space>", "a", "e", "i", "o", "u", "<sos/eos>"]

asr_transformer_encoder = TransformerEncoder(
    32,
    output_size=16,
    linear_units=16,
    num_blocks=2,
)

asr_transformer_decoder = TransformerDecoder(
    len(token_list),
    16,
    linear_units=16,
    num_blocks=2,
)

asr_ctc = CTC(odim=len(token_list), encoder_output_size=16)


@pytest.mark.parametrize(
    "enh_encoder, enh_decoder",
Example #10
from espnet2.diar.attractor.rnn_attractor import RnnAttractor
from espnet2.diar.decoder.linear_decoder import LinearDecoder
from espnet2.diar.espnet_model import ESPnetDiarizationModel
from espnet2.layers.label_aggregation import LabelAggregate

frontend = DefaultFrontend(
    n_fft=32,
    win_length=32,
    hop_length=16,
    n_mels=10,
)

encoder = TransformerEncoder(
    input_size=10,
    input_layer="linear",
    num_blocks=1,
    linear_units=32,
    output_size=16,
    attention_heads=2,
)

decoder = LinearDecoder(
    num_spk=2,
    encoder_output_size=encoder.output_size(),
)

rnn_attractor = RnnAttractor(unit=16,
                             encoder_output_size=encoder.output_size())

label_aggregator = LabelAggregate(
    win_length=32,
    hop_length=16,