def recog(self, xs_pad, ilens):
    assert xs_pad.size(0) == ilens.size(0), "Batch size mismatch"
    batch_size = xs_pad.size(0)
    enc_pad, enc_lens = self.extract_feat(xs_pad, ilens)
    enc_pad = enc_pad.transpose(0, 1)  # T * B * d_model
    enc_pad = self.pos_encoder(enc_pad)
    enc_pad_mask = make_bool_pad_mask(enc_lens)  # B * T (True means padding)
    enc_pad_mask = to_device(self, enc_pad_mask)
    decode_maxlen = enc_lens.max().item()
    memory = self.encoder(enc_pad, src_key_padding_mask=enc_pad_mask)
    # TODO: should we sample during decoding, or just use the logit??
    # use sampled char then embed currently
    # Init: <sos> for every utterance in the batch
    sos = torch.full((1, batch_size), self.sos_id, dtype=torch.int64)
    sos = to_device(self, sos)  # 1 * B
    # no decoded tokens yet; shape (0, B) so it concatenates with sos on the first step
    out = to_device(self, torch.empty(0, batch_size, dtype=torch.int64))
    for decode_step in range(1, decode_maxlen + 1):
        ys_in_pad = self.pre_embed(torch.cat([sos, out]))  # (decode_step) * B * d_model
        ys_in_pad = self.pos_encoder(ys_in_pad)
        tgt_self_attn_mask = to_device(self, generate_square_subsequent_mask(ys_in_pad.size(0)))
        out = self.decoder(ys_in_pad, memory,
                           tgt_mask=tgt_self_attn_mask,
                           memory_key_padding_mask=enc_pad_mask)
        out = self.char_trans(out)
        out = torch.argmax(out, dim=-1)
    return out  # L * B

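# --- Hedged sketch (not part of the original file) ---------------------------------
# recog() above relies on a project helper named generate_square_subsequent_mask.
# The sketch below shows the behaviour assumed here (a causal, additive mask in the
# style of torch.nn.Transformer); it is suffixed "_sketch" so it does not shadow the
# project's own helper, and it assumes `torch` is imported as in the rest of this file.
def generate_square_subsequent_mask_sketch(sz):
    # Float mask with 0.0 on/below the diagonal and -inf above it, so that decoder
    # position i can only attend to positions <= i.
    return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
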
def enc_forward(self, xs_pad, ilens, ys_pad, olens):
    assert xs_pad.size(0) == ilens.size(0) == len(ys_pad) == olens.size(0), "Batch size mismatch"
    batch_size = xs_pad.size(0)
    xs_pad = to_device(self, xs_pad)

    ## VGG forward
    xs_pad = xs_pad.view(batch_size, xs_pad.size(1), 1, xs_pad.size(2)).transpose(1, 2)  # B * 1 * T * D(83)
    xs_pad = self.feat_extractor(xs_pad)  # B * T * D' (128 * 20)
    ilens = torch.floor(ilens.to(dtype=torch.float32) / 4).to(dtype=torch.int64)
    xs_pad = xs_pad.transpose(1, 2)
    xs_pad = xs_pad.contiguous().view(batch_size, xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3))
    xs_pad = self.vgg2enc(xs_pad)  # B * T * d_model

    ## TransformerEncoder forward
    xs_pad = xs_pad.transpose(0, 1)  # T * B * d_model
    xs_pad = self.pos_encoder(xs_pad)
    pad_mask = make_bool_pad_mask(ilens)  # B * T (True means padding)
    pad_mask = to_device(self, pad_mask)
    enc_pad = self.encoder(xs_pad, src_key_padding_mask=pad_mask)  # T * B * d_model
    return enc_pad, ilens

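# --- Hedged sketch (not part of the original file) ---------------------------------
# enc_forward() above passes the output of make_bool_pad_mask as src_key_padding_mask,
# so the assumed contract is: given a LongTensor of lengths (B,), return a (B, Tmax)
# bool mask that is True at padded frames. A minimal version under that assumption,
# named "_sketch" to avoid clashing with the project's helper:
def make_bool_pad_mask_sketch(lengths):
    max_len = int(lengths.max().item())
    steps = torch.arange(max_len, device=lengths.device)  # (Tmax,)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)     # (B, Tmax), True = padding
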
def forward(self, xs_pad, ilens, ys, olens):
    assert xs_pad.size(0) == ilens.size(0) == len(ys) == olens.size(0), "Batch size mismatch"
    enc_pad, enc_lens = self.extract_feat(xs_pad, ilens)
    enc_pad = enc_pad.transpose(0, 1)  # T * B * d_model
    enc_pad = self.pos_encoder(enc_pad)
    # enc_pad_mask will be used in src_key_padding_mask
    # TODO: should we use enc_pad_mask for memory_key_padding_mask, ENABLE it now
    enc_pad_mask = make_bool_pad_mask(enc_lens)  # B * T (True means padding)
    enc_pad_mask = to_device(self, enc_pad_mask)
    ys_in_pad, ys_out_pad, olens = self.preprocess(ys, olens)  # ys_in_pad: L * B * d_model
    # tgt_self_attn_mask will be used in tgt_mask (subsequent mask)
    tgt_self_attn_mask = to_device(self, generate_square_subsequent_mask(ys_in_pad.size(0)))
    # tgt_pad_mask would serve as tgt_key_padding_mask (not passed below; the causal mask
    # already keeps non-padded positions from attending to later, padded ones)
    tgt_pad_mask = to_device(self, make_bool_pad_mask(olens))
    memory = self.encoder(enc_pad, src_key_padding_mask=enc_pad_mask)
    out = self.decoder(ys_in_pad, memory,
                       tgt_mask=tgt_self_attn_mask,
                       memory_key_padding_mask=enc_pad_mask)  # out: L * B * d_model
    out = out.transpose(0, 1)  # B * L * d_model
    logit = self.char_trans(out)  # B * L * odim
    return logit, ys_out_pad

def enc_forward(self, xs_pad, ilens, probe_layer):
    # Reshape (B, T, D) -> (B, 1, T, D) so the VGG front-end sees a single input channel
    xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), 1, xs_pad.size(2)).transpose(1, 2)
    xs_pad = self.vgg(xs_pad)
    if torch.is_tensor(ilens):
        ilens = ilens.cpu().numpy()
    else:
        ilens = np.array(ilens, dtype=np.float32)
    # The two max-pool layers in the VGG each halve the time axis
    enc_lens = np.array(np.ceil(ilens / 2), dtype=np.int64)
    enc_lens = np.array(np.ceil(np.array(enc_lens, dtype=np.float32) / 2), dtype=np.int64).tolist()
    xs_pad = xs_pad.transpose(1, 2)
    xs_pad = xs_pad.contiguous().view(xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3))
    if probe_layer == 0:
        # return VGG embedding
        mask = to_device(self, make_pad_mask(enc_lens).unsqueeze(-1))
        return xs_pad.masked_fill(mask, 0.0), enc_lens
    out, enc_lens, cur_state = self.blstm.enc_forward(xs_pad, enc_lens, probe_layer)
    mask = to_device(self, make_pad_mask(enc_lens).unsqueeze(-1))
    return out.masked_fill(mask, 0.0), enc_lens

def enc_forward(self, xs_pad, ilens, probe_layer):
    assert xs_pad.size(0) == ilens.size(0), "Batch size mismatch"
    batch_size = xs_pad.size(0)
    xs_pad = to_device(self, xs_pad)
    ilens = to_device(self, ilens)
    enc, enc_lens = self.encoder.enc_forward(xs_pad, ilens, probe_layer)
    return enc, enc_lens

def forward(self, comp_listner_feat, enc_lens, dec_h, att_prev, scaling=2.0):
    """AttLoc forward

    :param torch.Tensor comp_listner_feat: padded encoder hidden state (B, Tmax, enc_o_dim)
    :param torch.Tensor enc_lens: lengths of encoder output sequences (B)
    :param torch.Tensor dec_h: decoder hidden state (B, dec_dim)
    :param torch.Tensor att_prev: previous attention scores (att_w) (B, Tmax)
    :param float scaling: scaling parameter before applying softmax
    :return: attention weighted context vector (B, enc_dim)
    :rtype: torch.Tensor
    :return: attention score (B, Tmax)
    :rtype: torch.Tensor
    """
    batch_size = comp_listner_feat.size(0)
    # pre-compute all h outside the decoder loop (only calculated once)
    if self.precomp_enc_h is None:
        self.enc_h = comp_listner_feat  # (B, T, enc_dim)
        self.enc_len_max = self.enc_h.size(1)
        self.precomp_enc_h = self.mlp_enc(self.enc_h)  # (B, T, att_dim)
    dec_h = dec_h.view(batch_size, self.dec_dim)
    # initialize attention weight with uniform dist.
    if att_prev is None:
        att_prev = to_device(self, (~make_pad_mask(enc_lens)).float())
        # att_prev = att_prev / att_prev.sum(dim=1).unsqueeze(-1)
        att_prev = att_prev / enc_lens.float().unsqueeze(-1)  # (B, T)
    att_conv = self.loc_conv(att_prev.view(batch_size, 1, 1, self.enc_len_max))  # (B, C, 1, T)
    att_conv = att_conv.squeeze(2).transpose(1, 2)  # (B, T, C)
    att_conv = self.mlp_att(att_conv)  # (B, T, att_dim)
    dec_h = self.mlp_dec(dec_h).view(batch_size, 1, self.att_dim)  # (B, 1, att_dim)
    e = self.gvec(torch.tanh(att_conv + self.precomp_enc_h + dec_h)).squeeze(2)  # (B, T)
    if self.mask is None:
        self.mask = to_device(self, make_pad_mask(enc_lens))
    # mask out e at padded positions
    e.masked_fill_(self.mask, -float('inf'))
    w = F.softmax(scaling * e, dim=1)  # (B, T)
    c = torch.bmm(w.view(batch_size, 1, self.enc_len_max), self.enc_h).squeeze(1)  # (B, enc_dim)
    return c, w

def forward(self, xs_pad, enc_lens, prev_state=None):
    """RNNP forward

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, vgg_o_dim)
    :param list of int enc_lens: batch of lengths of input sequences (B)
    :param torch.Tensor prev_state: batch of previous RNN states
    :return: batch of hidden state sequences (B, Tmax, odim)
    :rtype: torch.Tensor
    """
    hid_states = []
    for i in range(self.nlayers):
        xs_pack = pack_padded_sequence(xs_pad, enc_lens, batch_first=True)
        rnn = getattr(self, f"rnn{i}")
        rnn.flatten_parameters()
        if prev_state is not None:
            prev_state = reset_backward_rnn_state(prev_state)
        ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[i])
        hid_states.append(states)
        ys_pad, enc_lens = pad_packed_sequence(ys, batch_first=True)  # ys_pad: (B, T, enc_dim)
        sub = self.sample_rate[i]
        if sub > 1:
            # subsample the time axis between layers
            ys_pad = ys_pad[:, ::sub]
            enc_lens = torch.LongTensor([int(length + 1) // sub for length in enc_lens])
        projected = getattr(self, f"bt{i}")(ys_pad.contiguous().view(-1, ys_pad.size(2)))  # (B*T, proj_dim)
        xs_pad = torch.tanh(projected.view(ys_pad.size(0), ys_pad.size(1), -1))
    return xs_pad, to_device(self, enc_lens), hid_states

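# --- Hedged sketch (not part of the original file) ---------------------------------
# forward() above calls reset_backward_rnn_state before carrying RNN states across
# chunks. The assumed behaviour is that the forward-direction states may be carried
# over while the backward-direction states must be zeroed. One plausible version
# under that assumption, with odd state indices holding the backward directions:
def reset_backward_rnn_state_sketch(states):
    if isinstance(states, (list, tuple)):
        for state in states:
            state[1::2] = 0.0
    else:
        states[1::2] = 0.0
    return states
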
def run_batch(self, cur_b, x, ilens, ys, olens, train):
    sos = ys[0].new([self.sos_id])
    eos = ys[0].new([self.eos_id])
    ys_out = [torch.cat([sos, y, eos], dim=0) for y in ys]
    olens += 2  # account for the added <sos> and <eos>
    y_true = torch.cat(ys_out)
    pred, enc_lens = self.asr_model(x, ilens)
    olens = to_device(self.asr_model, olens)
    pred = F.log_softmax(pred, dim=-1)  # (B, T, o_dim)
    loss = self.ctc_loss(pred.transpose(0, 1).contiguous(),
                         y_true.cuda().to(dtype=torch.long),
                         enc_lens.cpu().to(dtype=torch.long),
                         olens.cpu().to(dtype=torch.long))
    if train:
        info = {'loss': loss.item()}
        # if self.global_step % 5 == 0:
        if self.global_step % 500 == 0:
            self.probe_model(pred, ys)
        self.asr_opt.zero_grad()
        loss.backward()
    else:
        cer = self.metric_observer.batch_cal_er(pred.detach(), ys, ['ctc'], ['cer'])['ctc_cer']
        wer = self.metric_observer.batch_cal_er(pred.detach(), ys, ['ctc'], ['wer'])['ctc_wer']
        info = {'cer': cer, 'wer': wer, 'loss': loss.item()}
    return info

def forward(self, xs_pad, ilens):
    """
    xs_pad: (B, Tmax, 83)  # 83 = 80-dim fbank + 3-dim pitch
    ilens: torch.Tensor with size B
    """
    assert xs_pad.size(0) == ilens.size(0), "Batch size mismatch"
    batch_size = xs_pad.size(0)
    # Put data to device
    xs_pad = to_device(self, xs_pad)
    ilens = to_device(self, ilens)
    enc, enc_lens, _, _ = self.encoder(xs_pad, ilens)
    out = self.head(enc)
    return out, enc_lens

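# --- Hedged sketch (not part of the original file) ---------------------------------
# to_device(module, tensor) is used throughout this file; the assumed behaviour is
# simply "move the tensor onto the device the module's parameters live on":
def to_device_sketch(module, tensor):
    return tensor.to(next(module.parameters()).device)
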
def enc_forward(self, xs_pad, enc_lens, probe_layer, prev_state=None):
    hid_states = []
    assert probe_layer <= self.nlayers, \
        f"Probe layer ({probe_layer}) can't exceed # of layers ({self.nlayers})"
    for i in range(probe_layer):
        xs_pack = pack_padded_sequence(xs_pad, enc_lens, batch_first=True)
        rnn = getattr(self, f"rnn{i}")
        rnn.flatten_parameters()
        if prev_state is not None:
            prev_state = reset_backward_rnn_state(prev_state)
        ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[i])
        hid_states.append(states)
        ys_pad, enc_lens = pad_packed_sequence(ys, batch_first=True)  # ys_pad: (B, T, enc_dim)
        sub = self.sample_rate[i]
        if sub > 1:
            ys_pad = ys_pad[:, ::sub]
            enc_lens = torch.LongTensor([int(length + 1) // sub for length in enc_lens])
        projected = getattr(self, f"bt{i}")(ys_pad.contiguous().view(-1, ys_pad.size(2)))  # (B*T, proj_dim)
        xs_pad = torch.tanh(projected.view(ys_pad.size(0), ys_pad.size(1), -1))
    return xs_pad, to_device(self, enc_lens), hid_states

def beam_decode(self, x, ilen):
    assert x.size(0) == 1, "Batch size should be 1 in beam_decode"
    x = to_device(self, x)
    ilen = to_device(self, ilen)
    enc, enc_len, _, _ = self.encoder(x, ilen)
    out = self.head(enc)
    # beam_decode_ans = self._beam_search(out, enc_len, decode_beam_size)
    beam_decode_ans = self.tf_beam_decode(out, enc_len, decode_beam_size, nbest=decode_beam_size)
    return beam_decode_ans, enc_len

def preprocess(self, ys, olens):
    # TODO: should we set ys_out_pad to eos or ignore_id
    sos = ys[0].new([self.sos_id])
    eos = ys[0].new([self.eos_id])
    ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_in_pad = to_device(self, pad_sequence(ys_in, padding_value=self.eos_id, batch_first=False))  # L * B
    ys_in_pad = self.pre_embed(ys_in_pad)  # L * B * d_model
    # TODO: should we add logit scaling by sqrt(d_model) here??
    ys_in_pad = self.pos_encoder(ys_in_pad)
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]
    ys_out_pad = to_device(self, pad_sequence(ys_out, padding_value=IGNORE_ID, batch_first=True))  # B * L
    # ys_out_pad = pad_sequence(ys, padding_value=self.eos_id, batch_first=True)  # B * L
    olens += 1
    return ys_in_pad, ys_out_pad, olens

def extract_feat(self, xs_pad, ilens):
    batch_size = len(ilens)
    xs_pad = to_device(self, xs_pad.view(xs_pad.size(0), xs_pad.size(1), 1, xs_pad.size(2)).transpose(1, 2))
    enc_pad = self.feat_extractor(xs_pad)  # B * T * D' (128 * 20)
    enc_lens = torch.floor(ilens.to(dtype=torch.float32) / 4).to(dtype=torch.int64)
    enc_pad = enc_pad.transpose(1, 2)
    enc_pad = enc_pad.contiguous().view(batch_size, enc_pad.size(1), enc_pad.size(2) * enc_pad.size(3))
    enc_pad = self.vgg2enc(enc_pad)  # B * T * d_model
    return enc_pad, enc_lens

def preprocess(self, ys):
    """Generate decoder input and output labels from the target sequences

    Add <sos> to the decoder input, and add <eos> to the decoder output label.
    ys: list of 1-D label tensors, len B
    """
    # ys = [y[y != IGNORE_ID] for y in padded_input]  # parse padded ys
    # prepare input and output word sequences with sos/eos IDs
    sos = ys[0].new([self.sos_id])
    eos = ys[0].new([self.eos_id])
    ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]
    # padding: ys_in with <eos>, ys_out with IGNORE_ID (-1)
    # pys: utt x olen
    ys_in_pad = pad_list(ys_in, self.eos_id)
    ys_out_pad = pad_list(ys_out, IGNORE_ID)
    ys_in_pad = to_device(self, ys_in_pad)
    ys_out_pad = to_device(self, ys_out_pad)
    assert ys_in_pad.size() == ys_out_pad.size()
    return ys_in_pad, ys_out_pad

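# --- Hedged sketch (not part of the original file) ---------------------------------
# preprocess() above pads ys_in with eos_id and ys_out with IGNORE_ID via pad_list.
# The assumed contract is "stack a list of variable-length 1-D tensors into a
# (B, Lmax) tensor, filling the tail of each row with pad_value":
def pad_list_sketch(xs, pad_value):
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    padded = xs[0].new_full((n_batch, max_len), pad_value)
    for i, x in enumerate(xs):
        padded[i, :x.size(0)] = x
    return padded
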
def forward(self, xs_pad, ilens, prev_state=None):
    """Encoder forward

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, 83)
    :param torch.IntTensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor prev_state: batch of previous encoder hidden states (?, ...)
    :return: batch of hidden state sequences (B, Tmax, odim)
    :rtype: torch.Tensor
    """
    out, enc_lens = self.vgg(xs_pad, ilens)
    out, enc_lens, cur_state = self.blstm(out, enc_lens, prev_state)
    mask = to_device(self, make_pad_mask(enc_lens).unsqueeze(-1))
    return out.masked_fill(mask, 0.0), enc_lens, cur_state

def extract_feature(self, xs_pad, ilens):
    batch_size = xs_pad.size(0)
    xs_pad = to_device(self, xs_pad)
    xs_pad = xs_pad.view(batch_size, xs_pad.size(1), 1, xs_pad.size(2)).transpose(1, 2)  # B * 1 * T * D(83)
    xs_pad = self.feat_extractor(xs_pad)  # B * T * D' (128 * 20)
    ilens = torch.floor(ilens.to(dtype=torch.float32) / 4).to(dtype=torch.int64)
    xs_pad = xs_pad.transpose(1, 2)
    xs_pad = xs_pad.contiguous().view(batch_size, xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3))
    return xs_pad, ilens

def forward(self, xs_pad, ilens, prev_state=None):
    xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), 1, xs_pad.size(2)).transpose(1, 2)
    xs_pad = self.vgg(xs_pad)
    if torch.is_tensor(ilens):
        ilens = ilens.cpu().numpy()
    else:
        ilens = np.array(ilens, dtype=np.float32)
    enc_lens = np.array(np.ceil(ilens / 2), dtype=np.int64)
    enc_lens = np.array(np.ceil(np.array(enc_lens, dtype=np.float32) / 2), dtype=np.int64).tolist()
    xs_pad = xs_pad.transpose(1, 2)
    xs_pad = xs_pad.contiguous().view(xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3))
    vgg_out = xs_pad.clone()
    out, enc_lens, cur_state = self.blstm(xs_pad, enc_lens, prev_state)
    mask = to_device(self, make_pad_mask(enc_lens).unsqueeze(-1))
    return out.masked_fill(mask, 0.0), enc_lens, cur_state, vgg_out

def forward(self, xs_pad, ilens, ys_pad, decode_step, tf=False, tf_rate=0.0, vis_att=False):
    """
    xs_pad: (B, Tmax, 83)  # 83 = 80-dim fbank + 3-dim pitch
    ilens: torch.Tensor with size B
    ys_pad: (B, Lmax), will be parsed into a list of tensors for ctc
    """
    # NOTE: in each forward we visualize at most one utterance's attention map
    # (to compare the effect of different lengths), and pick the longest utterance
    # in this batch for simplicity
    assert xs_pad.size(0) == ilens.size(0) == ys_pad.size(0), "Batch size mismatch"
    batch_size = xs_pad.size(0)
    # Put data to device
    xs_pad = to_device(self, xs_pad)
    ilens = to_device(self, ilens)
    ys_pad = to_device(self, ys_pad)
    enc, enc_lens, _ = self.encoder(xs_pad, ilens)
    ctc_output = self.ctc_layer(enc)
    sos = ys_pad.new([self.sos_id])
    eos = ys_pad.new([self.eos_id])
    if tf:
        # ys_pad_in = torch.cat((sos.repeat(batch_size).view(batch_size, -1), ys_pad), 1)
        # teacher = self.embed(ys_pad_in)  # (B, L+1, dec_dim)
        teacher = self.embed(ys_pad)  # (B, L, dec_dim)
    self.decoder.init_rnn(enc)
    self.attention.reset()
    last_char_emb = self.embed(sos.repeat(batch_size))  # B * dec_dim
    output_char_seq = list()
    if vis_att:
        output_att_seq = list()
    att_w = None
    for t in range(decode_step):
        att_c, att_w = self.attention(enc, enc_lens, self.decoder.state_list[0], att_w)
        dec_inp = torch.cat([last_char_emb, att_c], dim=-1)  # (B, dec_dim + enc_o_dim)
        dec_out = self.decoder(dec_inp)  # (B, dec_dim)
        cur_char = self.char_trans(dec_out)  # (B, odim)
        # Teacher forcing
        if tf and t < decode_step - 1:
            # if random.random() < tf_rate:
            if torch.rand(1).item() < tf_rate:
                last_char_emb = teacher[:, t, :]
            else:
                # scheduled sampling
                sampled_char = Categorical(F.softmax(cur_char, dim=-1)).sample()
                last_char_emb = self.embed(sampled_char)
        else:
            # greedy pick
            last_char_emb = self.embed(torch.argmax(cur_char, dim=-1))
        output_char_seq.append(cur_char)
        if vis_att:
            output_att_seq.append(att_w.detach()[0].view(enc_lens[0], 1))
    att_output = torch.stack(output_char_seq, dim=1).view(batch_size * decode_step, -1)
    # att_output = torch.stack(output_char_seq, dim=1)  # (B, T, o_dim)
    att_map = torch.stack(output_att_seq, dim=1) if vis_att else None
    # att_map = torch.stack(output_att_seq, dim=1)  # (T, L)
    return ctc_output, enc_lens, att_output, att_map
