import pytest
import torch

from espnet2.asr.ctc import CTC
from espnet2.asr.encoder.transformer_encoder import TransformerEncoder


# NOTE: the parametrize grids below are assumed; the original decorators are
# not part of this excerpt.
@pytest.mark.parametrize("input_layer", ["linear", "conv2d", "embed"])
@pytest.mark.parametrize(
    "positionwise_layer_type", ["linear", "conv1d", "conv1d-linear"]
)
@pytest.mark.parametrize(
    "interctc_layer_idx, interctc_use_conditioning",
    [([], False), ([1], False), ([1], True)],
)
def test_Encoder_forward_backward(
    input_layer,
    positionwise_layer_type,
    interctc_layer_idx,
    interctc_use_conditioning,
):
    encoder = TransformerEncoder(
        20,
        output_size=40,
        input_layer=input_layer,
        positionwise_layer_type=positionwise_layer_type,
        interctc_layer_idx=interctc_layer_idx,
        interctc_use_conditioning=interctc_use_conditioning,
    )
    if input_layer == "embed":
        x = torch.randint(0, 10, [2, 10])
    else:
        x = torch.randn(2, 10, 20, requires_grad=True)
    x_lens = torch.LongTensor([10, 8])
    if len(interctc_layer_idx) > 0:
        ctc = None
        if interctc_use_conditioning:
            vocab_size = 5
            output_size = encoder.output_size()
            ctc = CTC(odim=vocab_size, encoder_output_size=output_size)
            encoder.conditioning_layer = torch.nn.Linear(vocab_size, output_size)
        y, _, _ = encoder(x, x_lens, ctc=ctc)
        # With intermediate CTC enabled, the first return value is the pair
        # (final_out, intermediate_outs); keep only the final output.
        y = y[0]
    else:
        y, _, _ = encoder(x, x_lens)
    y.sum().backward()
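# Standalone sketch (not part of the test suite) showing what the encoder
# returns when intermediate CTC branches are enabled; all sizes here are
# illustrative.
import torch

from espnet2.asr.encoder.transformer_encoder import TransformerEncoder

encoder = TransformerEncoder(20, output_size=40, num_blocks=2, interctc_layer_idx=[1])
x = torch.randn(2, 10, 20)
x_lens = torch.LongTensor([10, 8])
y, olens, _ = encoder(x, x_lens)
final_out, intermediate_outs = y             # y is a pair instead of a tensor
layer_idx, inter_out = intermediate_outs[0]  # one entry per tapped block
assert final_out.shape == inter_out.shape    # both (batch, frames, 40)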
def test_encoder_invalid_interctc_layer_idx():
    with pytest.raises(AssertionError):
        TransformerEncoder(
            20,
            num_blocks=2,
            interctc_layer_idx=[0, 1],
        )
    with pytest.raises(AssertionError):
        TransformerEncoder(
            20,
            num_blocks=2,
            interctc_layer_idx=[1, 2],
        )
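# Complementary positive case (an assumption inferred from the failing cases
# above: valid intermediate-CTC indices lie strictly between 0 and num_blocks):
def test_encoder_valid_interctc_layer_idx():
    TransformerEncoder(20, num_blocks=2, interctc_layer_idx=[1])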
# NOTE: the parametrize grids below are assumed; the original decorators are
# not part of this excerpt.
@pytest.mark.parametrize("input_layer", ["linear", "conv2d", "embed", None])
@pytest.mark.parametrize(
    "positionwise_layer_type", ["linear", "conv1d", "conv1d-linear"]
)
def test_Encoder_forward_backward(input_layer, positionwise_layer_type):
    encoder = TransformerEncoder(
        20,
        output_size=40,
        input_layer=input_layer,
        positionwise_layer_type=positionwise_layer_type,
    )
    if input_layer == "embed":
        x = torch.randint(0, 10, [2, 10])
    elif input_layer is None:
        # With input_layer=None the features must already have output_size dims.
        x = torch.randn(2, 10, 40, requires_grad=True)
    else:
        x = torch.randn(2, 10, 20, requires_grad=True)
    x_lens = torch.LongTensor([10, 8])
    y, _, _ = encoder(x, x_lens)
    y.sum().backward()
class TransformerTransducer(nn.Module):
    def __init__(self, config):
        super(TransformerTransducer, self).__init__()
        self.vocab_size = config.joint.vocab_size
        self.sos = self.vocab_size - 1
        self.eos = self.vocab_size - 1
        self.ignore_id = -1
        self.encoder_left_mask = config.mask.encoder_left_mask
        self.encoder_right_mask = config.mask.encoder_right_mask
        self.decoder_left_mask = config.mask.decoder_left_mask
        self.encoder = TransformerEncoder(**config.enc)
        self.decoder = TransformerEncoder(**config.dec)
        self.joint = JointNetwork(**config.joint)
        # TODO: check blank id
        self.loss = TransLoss(trans_type="warp-transducer", blank_id=0)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ):
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that the batch size is consistent across all inputs
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)

        # 1. Encoder (returns xs_pad, olens, None)
        encoder_out, encoder_out_lens, _ = self.encoder(
            speech,
            speech_lengths,
            left_mask=self.encoder_left_mask,
            right_mask=self.encoder_right_mask,
        )

        # 2. Decoder (label encoder); prepending <sos> right-shifts the input
        # TODO: check right shift for training
        text_in, text_out = add_sos_eos(text, self.sos, self.eos, self.ignore_id)
        text_in_lens = text_lengths + 1
        decoder_out, decoder_out_lens, _ = self.decoder(
            text_in,
            text_in_lens,
            left_mask=self.decoder_left_mask,
            right_mask=0,
        )

        # 3. Joint network
        # h_enc: batch of expanded hidden states (B, T, 1, D_enc)
        # h_dec: batch of expanded hidden states (B, 1, U, D_dec)
        encoder_out = encoder_out.unsqueeze(2)
        decoder_out = decoder_out.unsqueeze(1)
        joint_out = self.joint(h_enc=encoder_out, h_dec=decoder_out)

        # 4. Transducer loss over the predicted lattice
        loss = self.loss(
            pred_pad=joint_out,  # (batch, maxlen_in, maxlen_out + 1, odim)
            target=text.int(),  # (batch, maxlen_out)
            pred_len=speech_lengths.int(),  # (batch,)
            target_len=text_lengths.int(),  # (batch,)
        )
        return loss

    @torch.no_grad()
    def decode(self, enc_state, lengths):
        token_list = [self.sos]
        device = torch.device("cuda" if enc_state.is_cuda else "cpu")
        token = torch.tensor([token_list], dtype=torch.long).to(device)
        decoder_out, decoder_out_lens, _ = self.decoder.forward_one_step(
            token, self.decoder_left_mask
        )
        decoder_out = decoder_out[:, -1, :]
        for t in range(lengths):
            logits = self.joint(enc_state[t].view(-1), decoder_out.view(-1))
            out = F.softmax(logits, dim=0).detach()
            pred = torch.argmax(out, dim=0)
            pred = int(pred.item())
            if pred != 0:  # skip the blank id
                token_list.append(pred)
                token = torch.tensor([token_list], dtype=torch.long)
                if enc_state.is_cuda:
                    token = token.cuda()
                # Feed the whole token history, but keep only the last output
                decoder_out, decoder_out_lens, _ = self.decoder.forward_one_step(
                    token, self.decoder_left_mask
                )
                decoder_out = decoder_out[:, -1, :]
        # Drop the leading <sos>
        return token_list[1:]

    @torch.no_grad()
    def recognize(self, speech: torch.Tensor, speech_lengths: torch.Tensor) -> list:
        batch_size = speech.size(0)
        encoder_out, encoder_out_lens, _ = self.encoder(
            speech,
            speech_lengths,
            left_mask=self.encoder_left_mask,
            right_mask=self.encoder_right_mask,
        )
        results = []
        for batch in range(batch_size):
            decoded_seq = self.decode(encoder_out[batch], speech_lengths[batch])
            results.append(decoded_seq)
        return results
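# Why the two unsqueeze calls in step 3 of forward() above line up: a minimal,
# self-contained sketch of the transducer joint broadcasting, using a toy
# additive joint (the real JointNetwork projects both inputs and applies a
# nonlinearity before the output layer).
import torch

B, T, U, D = 2, 5, 3, 8
enc = torch.randn(B, T, D)  # one encoder vector per acoustic frame
dec = torch.randn(B, U, D)  # one prediction-network vector per label
h_enc = enc.unsqueeze(2)    # (B, T, 1, D)
h_dec = dec.unsqueeze(1)    # (B, 1, U, D)
joint = h_enc + h_dec       # broadcasts to (B, T, U, D): every frame-label pair
assert joint.shape == (B, T, U, D)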
def __init__(
    self,
    n_fft: int = 256,
    win_length: int = None,
    hop_length: int = 128,
    dnn_type: str = "transformer",
    dropout: float = 0.0,
    num_spk: int = 2,
    nonlinear: str = "sigmoid",
    utt_mvn: bool = False,
    mask_type: str = "IRM",
    loss_type: str = "mask_mse",
    d_model: int = 256,
    nhead: int = 4,
    linear_units: int = 2048,
    num_layers: int = 6,
    dropout_rate: float = 0.1,
    positional_dropout_rate: float = 0.1,
    attention_dropout_rate: float = 0.0,
    input_layer: Optional[str] = "linear",
    pos_enc_class=PositionalEncoding,
    normalize_before: bool = True,
    concat_after: bool = False,
    positionwise_layer_type: str = "linear",
    positionwise_conv_kernel_size: int = 1,
    padding_idx: int = -1,
):
    super(TFMaskingTransformer, self).__init__()

    self.num_spk = num_spk
    self.num_bin = n_fft // 2 + 1
    self.mask_type = mask_type
    self.loss_type = loss_type
    if loss_type not in ("mask_mse", "magnitude", "spectrum"):
        raise ValueError("Unsupported loss type: %s" % loss_type)

    self.stft = Stft(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
    )

    if utt_mvn:
        self.utt_mvn = UtteranceMVN(norm_means=True, norm_vars=True)
    else:
        self.utt_mvn = None

    # The former RNN-based mask estimator is replaced by a Transformer encoder.
    self.encoder = TransformerEncoder(
        input_size=self.num_bin,
        output_size=d_model,
        attention_heads=nhead,
        linear_units=linear_units,
        num_blocks=num_layers,
        dropout_rate=dropout_rate,
        positional_dropout_rate=positional_dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        input_layer=input_layer,
        pos_enc_class=pos_enc_class,
        normalize_before=normalize_before,
        concat_after=concat_after,
        positionwise_layer_type=positionwise_layer_type,
        positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        padding_idx=padding_idx,
    )
    # One mask head per speaker
    self.linear = torch.nn.ModuleList(
        [torch.nn.Linear(d_model, self.num_bin) for _ in range(self.num_spk)]
    )

    if nonlinear not in ("sigmoid", "relu", "tanh"):
        raise ValueError("Not supporting nonlinear={}".format(nonlinear))
    self.nonlinear = {
        "sigmoid": torch.nn.Sigmoid(),
        "relu": torch.nn.ReLU(),
        "tanh": torch.nn.Tanh(),
    }[nonlinear]
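# The forward pass is not included in this excerpt. A minimal sketch of how
# the components declared above would plausibly combine into per-speaker
# masks; the method name, shapes, and feature pipeline are assumptions, not
# the actual implementation:
def _estimate_masks(self, feature, flens):
    # feature: (batch, time, num_bin) magnitude spectrogram from self.stft
    if self.utt_mvn is not None:
        feature, flens = self.utt_mvn(feature, flens)
    hidden, hlens, _ = self.encoder(feature, flens)  # (batch, time', d_model)
    # one linear head plus the shared nonlinearity per speaker
    masks = [self.nonlinear(linear(hidden)) for linear in self.linear]
    return masks, hlens  # num_spk masks, each (batch, time', num_bin)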
def test_Encoder_invalid_type():
    with pytest.raises(ValueError):
        TransformerEncoder(20, input_layer="fff")
def test_Encoder_output_size():
    encoder = TransformerEncoder(20, output_size=256)
    assert encoder.output_size() == 256
fix_order_solver = FixedOrderSolver(criterion=si_snr_loss)

default_frontend = DefaultFrontend(
    fs=300,
    n_fft=32,
    win_length=32,
    hop_length=24,
    n_mels=32,
)

token_list = ["<blank>", "<space>", "a", "e", "i", "o", "u", "<sos/eos>"]

asr_transformer_encoder = TransformerEncoder(
    32,
    output_size=16,
    linear_units=16,
    num_blocks=2,
)

asr_transformer_decoder = TransformerDecoder(
    len(token_list),
    16,
    linear_units=16,
    num_blocks=2,
)

asr_ctc = CTC(odim=len(token_list), encoder_output_size=16)


@pytest.mark.parametrize(
    "enh_encoder, enh_decoder",
from espnet2.diar.attractor.rnn_attractor import RnnAttractor
from espnet2.diar.decoder.linear_decoder import LinearDecoder
from espnet2.diar.espnet_model import ESPnetDiarizationModel
from espnet2.layers.label_aggregation import LabelAggregate

frontend = DefaultFrontend(
    n_fft=32,
    win_length=32,
    hop_length=16,
    n_mels=10,
)

encoder = TransformerEncoder(
    input_size=10,
    input_layer="linear",
    num_blocks=1,
    linear_units=32,
    output_size=16,
    attention_heads=2,
)

decoder = LinearDecoder(
    num_spk=2,
    encoder_output_size=encoder.output_size(),
)

rnn_attractor = RnnAttractor(unit=16, encoder_output_size=encoder.output_size())

label_aggregator = LabelAggregate(
    win_length=32,
    hop_length=16,