def recog(self, xs_pad, ilens):
        assert xs_pad.size(0) == ilens.size(0), "Batch size mismatch"
        batch_size = xs_pad.size(0)

        enc_pad, enc_lens = self.extract_feat(xs_pad, ilens)
        enc_pad = enc_pad.transpose(0,1) # T * B * d_model
        enc_pad = self.pos_encoder(enc_pad)
        enc_pad_mask = make_bool_pad_mask(enc_lens) # (B, T), True marks padding positions
        enc_pad_mask = to_device(self, enc_pad_mask)

        decode_maxlen = enc_lens.max().item()
        
        memory = self.encoder(enc_pad, src_key_padding_mask = enc_pad_mask)

        # TODO: should we sample during decoding, or feed back the logits directly?
        # Currently the argmax char is embedded and fed back.

        # Init
        sos = torch.full((1, batch_size), self.sos_id, dtype=torch.int64)
        sos = to_device(self, sos) # 1 * B
        # start from an empty (0, B) token tensor so torch.cat([sos, out]) is well-formed
        out = to_device(self, torch.empty(0, batch_size, dtype=torch.int64))

        for decode_step in range(1,decode_maxlen+1):
            ys_in_pad = self.pre_embed(torch.cat([sos, out])) # (decode_step) * B * d_model
            ys_in_pad = self.pos_encoder(ys_in_pad)
            tgt_self_attn_mask = to_device(self, generate_square_subsequent_mask(ys_in_pad.size(0)))

            out = self.decoder(ys_in_pad, memory, 
                               tgt_mask = tgt_self_attn_mask,
                               memory_key_padding_mask = enc_pad_mask)
            out = self.char_trans(out)
            out = torch.argmax(out, dim=-1)

        return out # L * B
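
# The mask helpers used above are not defined in this snippet. A minimal sketch of
# what they are assumed to return, matching how src_key_padding_mask and tgt_mask
# are consumed by nn.TransformerEncoder / nn.TransformerDecoder. make_pad_mask,
# used further below, is assumed to behave like make_bool_pad_mask.
import torch

def make_bool_pad_mask(lengths):
    """Assumed helper: (B, T) bool mask, True at padded positions."""
    lengths = torch.as_tensor(lengths)
    steps = torch.arange(int(lengths.max()), device=lengths.device)  # (T,)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)                # (B, T)

def generate_square_subsequent_mask(size):
    """Assumed helper: (L, L) float mask, -inf above the diagonal so that
    position i can only attend to positions <= i."""
    return torch.triu(torch.full((size, size), float('-inf')), diagonal=1)
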
    def enc_forward(self, xs_pad, ilens, ys_pad, olens):

        assert xs_pad.size(0) == ilens.size(0) == len(ys_pad) == olens.size(0), "Batch size mismatch"
        batch_size = xs_pad.size(0)

        xs_pad = to_device(self,xs_pad)

        ## VGG forward
        xs_pad = xs_pad.view(batch_size, xs_pad.size(1), 1, xs_pad.size(2)).transpose(1,2) # B * 1 * T * D(83)
        xs_pad = self.feat_extractor(xs_pad) # B * C(128) * T/4 * D'(20)
        ilens = torch.floor(ilens.to(dtype=torch.float32)/4).to(dtype=torch.int64)
        xs_pad = xs_pad.transpose(1,2)
        xs_pad = xs_pad.contiguous().view(batch_size,  xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3))
        xs_pad = self.vgg2enc(xs_pad) # B * T * d_model


        ## TransformerEncoder forward
        xs_pad = xs_pad.transpose(0,1) # T * B * d_model
        xs_pad = self.pos_encoder(xs_pad)
        pad_mask = make_bool_pad_mask(ilens) # (B, T), True marks padding positions
        pad_mask = to_device(self, pad_mask)

        enc_pad = self.encoder(xs_pad, src_key_padding_mask=pad_mask) # T * B * d_model

        return enc_pad, ilens
    def forward(self, xs_pad, ilens, ys, olens):
        assert xs_pad.size(0) == ilens.size(0) == len(ys) == olens.size(0), "Batch size mismatch"

        enc_pad, enc_lens = self.extract_feat(xs_pad, ilens)

        enc_pad = enc_pad.transpose(0,1) # T * B * d_model
        enc_pad = self.pos_encoder(enc_pad)

        # enc_pad_mask will be used in src_key_padding_mask
        # TODO: should enc_pad_mask also be used as memory_key_padding_mask? Enabled for now.
        enc_pad_mask = make_bool_pad_mask(enc_lens) # (B, T), True marks padding positions
        enc_pad_mask = to_device(self, enc_pad_mask)

        ys_in_pad, ys_out_pad, olens = self.preprocess(ys, olens)
        # ys_in_pad: L * B * d_model

        # tgt_self_attn_mask will be used in tgt_mask (subsequent mask)
        tgt_self_attn_mask = to_device(self, generate_square_subsequent_mask(ys_in_pad.size(0)))

        # tgt_pad_mask will be used in tgt_key_padding_mask
        tgt_pad_mask = to_device(self, make_bool_pad_mask(olens))

        memory = self.encoder(enc_pad, src_key_padding_mask = enc_pad_mask)
        out = self.decoder(ys_in_pad, memory, 
                           tgt_mask=tgt_self_attn_mask,
                           memory_key_padding_mask=enc_pad_mask)
        # out: L * B * d_model
        out = out.transpose(0,1) # B * L * d_model
        logit = self.char_trans(out) # B * L * odim

        return logit, ys_out_pad
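
# A usage sketch (not part of the original source): the (logit, ys_out_pad) pair
# returned by forward() is typically trained with cross-entropy that skips padded
# targets. IGNORE_ID is assumed to be the padding label used in preprocess().
import torch.nn.functional as F

IGNORE_ID = -1  # assumed value of the padding label

def seq_cross_entropy(logit, ys_out_pad):
    """Cross-entropy over (B, L, odim) logits vs. (B, L) targets,
    ignoring positions padded with IGNORE_ID."""
    return F.cross_entropy(logit.reshape(-1, logit.size(-1)),
                           ys_out_pad.reshape(-1),
                           ignore_index=IGNORE_ID)

# e.g.  logit, ys_out_pad = model(xs_pad, ilens, ys, olens)
#       loss = seq_cross_entropy(logit, ys_out_pad)
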
    def enc_forward(self, xs_pad, ilens, probe_layer):
        xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), 1,
                             xs_pad.size(2)).transpose(1, 2)
        xs_pad = self.vgg(xs_pad)

        if torch.is_tensor(ilens):
            ilens = ilens.cpu().numpy()
        else:
            ilens = np.array(ilens, dtype=np.float32)
        # the VGG block halves the time axis twice: ceil(T/2) applied twice
        enc_lens = np.ceil(ilens / 2).astype(np.int64)
        enc_lens = np.ceil(enc_lens.astype(np.float32) / 2).astype(np.int64).tolist()

        xs_pad = xs_pad.transpose(1, 2)
        xs_pad = xs_pad.contiguous().view(xs_pad.size(0), xs_pad.size(1),
                                          xs_pad.size(2) * xs_pad.size(3))

        if probe_layer == 0:  # return VGG embedding
            mask = to_device(self, make_pad_mask(enc_lens).unsqueeze(-1))
            return xs_pad.masked_fill(mask, .0), enc_lens

        out, enc_lens, cur_state = self.blstm.enc_forward(
            xs_pad, enc_lens, probe_layer)

        mask = to_device(self, make_pad_mask(enc_lens).unsqueeze(-1))
        return out.masked_fill(mask, .0), enc_lens
    def enc_forward(self, xs_pad, ilens, probe_layer):
        assert xs_pad.size(0) == ilens.size(0), "Batch size mismatch"
        batch_size = xs_pad.size(0)

        xs_pad = to_device(self, xs_pad)
        ilens = to_device(self, ilens)

        enc, enc_lens = self.encoder.enc_forward(xs_pad, ilens, probe_layer)
        return enc, enc_lens
    def forward(self,
                comp_listner_feat,
                enc_lens,
                dec_h,
                att_prev,
                scaling=2.0):
        """AttLoc forward

        :param torch.Tensor comp_listner_feat: padded encoder hidden state (B, Tmax, enc_o_dim)
        :param list of int enc_lens: lengths of encoder output sequences (B)
        :param dec_h: decoder hidden state (B, dec_dim)
        :param att_prev: previous attention scores (att_w) (B,Tmax)
        :param float scaling: scaling parameter before applying softmax

        :return: attention-weighted context vector (B, enc_dim)
        :rtype: torch.Tensor
        :return: attention score (B, Tmax)
        :rtype: torch.Tensor
        """
        batch_size = comp_listner_feat.size(0)

        # pre-compute all h outside the decoder loop
        if self.precomp_enc_h is None:
            self.enc_h = comp_listner_feat  # (B, T, enc_dim)
            self.enc_len_max = self.enc_h.size(1)
            # Only calculate once
            self.precomp_enc_h = self.mlp_enc(self.enc_h)  # (B, T, att_dim)

        dec_h = dec_h.view(batch_size, self.dec_dim)

        # initialize attention weight with uniform dist.
        if att_prev is None:
            att_prev = to_device(self, (~make_pad_mask(enc_lens)).float())
            # att_prev = att_prev / att_prev.sum(dim=1).unsqueeze(-1)
            att_prev = att_prev / enc_lens.float().unsqueeze(-1)  # (B, T)

        att_conv = self.loc_conv(
            att_prev.view(batch_size, 1, 1, self.enc_len_max))  # (B, C, 1, T)
        att_conv = att_conv.squeeze(2).transpose(1, 2)  # (B, T, C)
        att_conv = self.mlp_att(att_conv)  # (B, T, att_dim)
        dec_h = self.mlp_dec(dec_h).view(batch_size, 1,
                                         self.att_dim)  # (B, 1, att_dim)
        e = self.gvec(torch.tanh(att_conv + self.precomp_enc_h +
                                 dec_h)).squeeze(2)  # (B, T)

        if self.mask is None:
            self.mask = to_device(self, make_pad_mask(enc_lens))

        # masked out e according to mask
        e.masked_fill_(self.mask, -float('inf'))
        w = F.softmax(scaling * e, dim=1)  #(B, T)
        c = torch.bmm(w.view(batch_size, 1, self.enc_len_max),
                      self.enc_h).squeeze(1)  # (B, enc_dim)

        return c, w
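
# The projection layers that AttLoc.forward relies on are not shown here. A sketch
# of how they are assumed to be wired (ESPnet-style location-aware attention,
# e = gvec(tanh(mlp_enc(h) + mlp_dec(d) + mlp_att(conv(att_prev))))); the
# hyperparameter names aconv_chans / aconv_filts are illustrative assumptions.
import torch.nn as nn

class AttLocLayers(nn.Module):
    """Sketch of the layers assumed by the AttLoc.forward above."""
    def __init__(self, enc_dim, dec_dim, att_dim, aconv_chans, aconv_filts):
        super().__init__()
        self.mlp_enc = nn.Linear(enc_dim, att_dim)              # projects enc_h
        self.mlp_dec = nn.Linear(dec_dim, att_dim, bias=False)  # projects dec_h
        self.mlp_att = nn.Linear(aconv_chans, att_dim, bias=False)
        # conv over the previous attention weights: the "location" feature
        self.loc_conv = nn.Conv2d(1, aconv_chans, (1, 2 * aconv_filts + 1),
                                  padding=(0, aconv_filts), bias=False)
        self.gvec = nn.Linear(att_dim, 1)                       # scores -> (B, T, 1)
        self.dec_dim, self.att_dim = dec_dim, att_dim
        self.enc_h = self.precomp_enc_h = self.mask = None      # reset per utterance
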
    def forward(self, xs_pad, enc_lens, prev_state=None):
        """RNNP forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, vgg_o_dim)
        :param list of int enc_lens
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, odim)
        :rtype: torch.Tensor
        """
        hid_states = []

        for i in range(self.nlayers):
            xs_pack = pack_padded_sequence(xs_pad, enc_lens, batch_first=True)
            rnn = getattr(self, f"rnn{i}")
            rnn.flatten_parameters()
            if prev_state is not None:
                prev_state = reset_backward_rnn_state(prev_state)
            ys, states = rnn(xs_pack,
                             hx=None if prev_state is None else prev_state[i])
            hid_states.append(states)

            ys_pad, enc_lens = pad_packed_sequence(ys, batch_first=True)
            # ys_pad: (B, T, enc_dim)

            sub = self.sample_rate[i]
            if sub > 1:
                ys_pad = ys_pad[:, ::sub]
                enc_lens = torch.LongTensor(
                    [int(length + 1) // sub for length in enc_lens])
            projected = getattr(self, f"bt{i}")(ys_pad.contiguous().view(
                -1, ys_pad.size(2)))  #(B*T, proj_dim)
            xs_pad = torch.tanh(
                projected.view(ys_pad.size(0), ys_pad.size(1), -1))

        return xs_pad, to_device(self, enc_lens), hid_states
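
# reset_backward_rnn_state is assumed to zero the hidden states belonging to the
# backward direction of the bidirectional RNN, so they are not carried across
# chunks of a streamed input. A minimal sketch of that assumption:
def reset_backward_rnn_state(states):
    """Zero the backward-direction states (odd layer indices of a BRNN state)."""
    if isinstance(states, (list, tuple)):
        for state in states:
            state[1::2] = 0.0
    else:
        states[1::2] = 0.0
    return states
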
    def run_batch(self, cur_b, x, ilens, ys, olens, train):
        sos = ys[0].new([self.sos_id])
        eos = ys[0].new([self.eos_id])
        ys_out = [torch.cat([sos, y, eos], dim=0) for y in ys]
        olens += 2  # account for the added <sos> and <eos>

        y_true = torch.cat(ys_out)

        pred, enc_lens = self.asr_model(x, ilens)
        olens = to_device(self.asr_model, olens)
        pred = F.log_softmax(pred, dim=-1)  # (B, T, o_dim)

        loss = self.ctc_loss(
            pred.transpose(0, 1).contiguous(),
            y_true.cuda().to(dtype=torch.long),
            enc_lens.cpu().to(dtype=torch.long),
            olens.cpu().to(dtype=torch.long))

        if train:
            info = {'loss': loss.item()}
            # if self.global_step % 5 == 0:
            if self.global_step % 500 == 0:
                self.probe_model(pred, ys)
            self.asr_opt.zero_grad()
            loss.backward()

        else:
            cer = self.metric_observer.batch_cal_er(
                pred.detach(), ys, ['ctc'], ['cer'])['ctc_cer']
            wer = self.metric_observer.batch_cal_er(
                pred.detach(), ys, ['ctc'], ['wer'])['ctc_wer']
            info = {'cer': cer, 'wer': wer, 'loss': loss.item()}

        return info
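
# run_batch assumes self.ctc_loss was constructed elsewhere. A self-contained
# sketch consistent with how it is called above (log-probs of shape (T, B, odim),
# a 1-D concatenation of targets, and per-sequence lengths); blank=0 and the
# reduction are assumptions, not taken from the source.
import torch
import torch.nn as nn

ctc_loss = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

T, B, odim, L = 50, 2, 30, 12
log_probs = torch.randn(T, B, odim).log_softmax(dim=-1)   # like pred.transpose(0, 1)
targets = torch.randint(1, odim, (B, L)).flatten()        # like y_true = torch.cat(ys_out)
input_lengths = torch.full((B,), T, dtype=torch.long)     # like enc_lens
target_lengths = torch.full((B,), L, dtype=torch.long)    # like olens
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
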
    def forward(self, xs_pad, ilens):
        """
        xs_pad: (B, Tmax, 83)  #83 is 80-dim fbank + 3-dim pitch
        ilens: torch.Tensor with size B
        """
        assert xs_pad.size(0) == ilens.size(0), "Batch size mismatch"
        batch_size = xs_pad.size(0)

        # Put data to device
        xs_pad = to_device(self, xs_pad)
        ilens = to_device(self, ilens)

        enc, enc_lens, _, _ = self.encoder(xs_pad, ilens)
        out = self.head(enc)

        return out, enc_lens
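
# The (out, enc_lens) pair returned above feeds a CTC head. A sketch of greedy
# (best-path) decoding: argmax per frame, collapse repeats, drop blanks. The
# blank index is an assumption.
import torch

def ctc_greedy_decode(logits, enc_lens, blank_id=0):
    """Best-path decode for (B, T, odim) CTC logits; blank_id is assumed to be 0."""
    hyps = []
    ids = logits.argmax(dim=-1)  # (B, T)
    for b, length in enumerate(enc_lens):
        prev, seq = blank_id, []
        for tok in ids[b, :int(length)].tolist():
            if tok != blank_id and tok != prev:
                seq.append(tok)
            prev = tok
        hyps.append(seq)
    return hyps
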
    def enc_forward(self, xs_pad, enc_lens, probe_layer, prev_state=None):
        hid_states = []

        assert probe_layer <= self.nlayers,\
            f"Probe layer ({probe_layer}) can't exceed # of layers ({self.nlayers})"

        for i in range(probe_layer):
            xs_pack = pack_padded_sequence(xs_pad, enc_lens, batch_first=True)
            rnn = getattr(self, f"rnn{i}")
            rnn.flatten_parameters()
            if prev_state is not None:
                prev_state = reset_backward_rnn_state(prev_state)
            ys, states = rnn(xs_pack,
                             hx=None if prev_state is None else prev_state[i])
            hid_states.append(states)

            ys_pad, enc_lens = pad_packed_sequence(ys, batch_first=True)
            # ys_pad: (B, T, enc_dim)

            sub = self.sample_rate[i]
            if sub > 1:
                ys_pad = ys_pad[:, ::sub]
                enc_lens = torch.LongTensor(
                    [int(length + 1) // sub for length in enc_lens])
            projected = getattr(self, f"bt{i}")(ys_pad.contiguous().view(
                -1, ys_pad.size(2)))  #(B*T, proj_dim)
            xs_pad = torch.tanh(
                projected.view(ys_pad.size(0), ys_pad.size(1), -1))

        return xs_pad, to_device(self, enc_lens), hid_states
    def beam_decode(self, x, ilen):
        assert x.size(0) == 1, "Batch size should be 1 in beam_decode"

        x = to_device(self, x)
        ilen = to_device(self, ilen)

        enc, enc_len, _, _ = self.encoder(x, ilen)
        out = self.head(enc)

        #beam_decode_ans = self._beam_search(out, enc_len, decode_beam_size)
        beam_decode_ans = self.tf_beam_decode(out,
                                              enc_len,
                                              decode_beam_size,
                                              nbest=decode_beam_size)
        #embed()

        return beam_decode_ans, enc_len
    def preprocess(self, ys, olens):

        # TODO: should ys_out_pad be padded with eos or ignore_id?
        sos = ys[0].new([self.sos_id])
        eos = ys[0].new([self.eos_id])
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_in_pad = to_device(self,pad_sequence(ys_in, padding_value=self.eos_id, batch_first=False)) # L*B
        ys_in_pad = self.pre_embed(ys_in_pad) # L * B * d_model
        # TODO: should the embedding be scaled by sqrt(d_model) here?
        ys_in_pad = self.pos_encoder(ys_in_pad)

        ys_out = [torch.cat([y, eos], dim=0) for y in ys]
        ys_out_pad = to_device(self,pad_sequence(ys_out, padding_value=IGNORE_ID, batch_first=True)) # B*L
        # ys_out_pad = pad_sequence(ys, padding_value=self.eos_id, batch_first=True) # B * L

        olens += 1

        return ys_in_pad, ys_out_pad, olens
    def extract_feat(self, xs_pad, ilens):
        batch_size = len(ilens)
        xs_pad = to_device(self,xs_pad.view(xs_pad.size(0), xs_pad.size(1), 1, xs_pad.size(2)).transpose(1,2))
        enc_pad = self.feat_extractor(xs_pad) # B * C(128) * T/4 * D'(20)
        enc_lens = torch.floor(ilens.to(dtype=torch.float32)/4).to(dtype=torch.int64)
        enc_pad = enc_pad.transpose(1,2)
        enc_pad = enc_pad.contiguous().view(batch_size,  enc_pad.size(1), enc_pad.size(2) * enc_pad.size(3))
        enc_pad = self.vgg2enc(enc_pad) # B * T * d_model

        return enc_pad, enc_lens
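
# feat_extractor / vgg2enc are assumed to form a VGG-style 2-D conv front end that
# downsamples time by 4 (hence floor(ilens / 4)) and leaves 128 channels x 20
# frequency bins from 83-dim input, followed by a linear projection to d_model.
# A sketch under those assumptions:
import torch.nn as nn

def build_vgg_frontend(d_model, in_freq=83):
    """Two conv blocks, each ending in a 2x2 max-pool, so T and D shrink by 4."""
    feat_extractor = nn.Sequential(
        nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
        nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
        nn.MaxPool2d(2, stride=2),
        nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
        nn.Conv2d(128, 128, 3, padding=1), nn.ReLU(),
        nn.MaxPool2d(2, stride=2),
    )  # (B, 1, T, 83) -> (B, 128, T // 4, 20)
    vgg2enc = nn.Linear(128 * (in_freq // 4), d_model)
    return feat_extractor, vgg2enc
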
    def preprocess(self, ys):
        """Generate decoder input and output label from padded_input
        Add <sos> to decoder input, and add <eos> to decoder output label
        ys: list, len B
        olens: B
        """
        # ys = [y[y != IGNORE_ID] for y in padded_input]  # parse padded ys
        # prepare input and output word sequences with sos/eos IDs
        sos = ys[0].new([self.sos_id])
        eos = ys[0].new([self.eos_id])

        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]

        # pad ys_in with <eos> and ys_out with IGNORE_ID
        # padded shape: utt x olen
        ys_in_pad = pad_list(ys_in, self.eos_id)
        ys_out_pad = pad_list(ys_out, IGNORE_ID)
        ys_in_pad = to_device(self, ys_in_pad)
        ys_out_pad = to_device(self, ys_out_pad)

        assert ys_in_pad.size() == ys_out_pad.size()
        return ys_in_pad, ys_out_pad
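
# pad_list is assumed to stack a list of variable-length 1-D label tensors into a
# (B, Lmax) tensor filled with the given pad value. A minimal sketch of that
# assumption:
def pad_list(xs, pad_value):
    """Pad a list of 1-D tensors to a (B, Lmax) tensor filled with pad_value."""
    max_len = max(x.size(0) for x in xs)
    padded = xs[0].new_full((len(xs), max_len), pad_value)
    for i, x in enumerate(xs):
        padded[i, :x.size(0)] = x
    return padded
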
    def forward(self, xs_pad, ilens, prev_state=None):
        """Encoder forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, 83)
        :param torch.IntTensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous encoder hidden states (?, ...)

        :return: batch of hidden state sequences (B, Tmax, odim)
        :rtype: torch.Tensor
        """
        out, enc_lens = self.vgg(xs_pad, ilens)
        out, enc_lens, cur_state = self.blstm(out, enc_lens, prev_state)
        mask = to_device(self, make_pad_mask(enc_lens).unsqueeze(-1))

        return out.masked_fill(mask, .0), enc_lens, cur_state
    def extract_feature(self, xs_pad, ilens):
        batch_size = xs_pad.size(0)

        xs_pad = to_device(self, xs_pad)
        xs_pad = xs_pad.view(batch_size, xs_pad.size(1), 1,
                             xs_pad.size(2)).transpose(1,
                                                       2)  # B * 1 * T * D(83)
        xs_pad = self.feat_extractor(xs_pad)  # B * C(128) * T/4 * D'(20)
        ilens = torch.floor(ilens.to(dtype=torch.float32) /
                            4).to(dtype=torch.int64)
        xs_pad = xs_pad.transpose(1, 2)
        xs_pad = xs_pad.contiguous().view(batch_size, xs_pad.size(1),
                                          xs_pad.size(2) * xs_pad.size(3))

        return xs_pad, ilens
    def forward(self, xs_pad, ilens, prev_state=None):
        xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), 1,
                             xs_pad.size(2)).transpose(1, 2)
        xs_pad = self.vgg(xs_pad)

        if torch.is_tensor(ilens):
            ilens = ilens.cpu().numpy()
        else:
            ilens = np.array(ilens, dtype=np.float32)
        # the VGG block halves the time axis twice: ceil(T/2) applied twice
        enc_lens = np.ceil(ilens / 2).astype(np.int64)
        enc_lens = np.ceil(enc_lens.astype(np.float32) / 2).astype(np.int64).tolist()

        xs_pad = xs_pad.transpose(1, 2)
        xs_pad = xs_pad.contiguous().view(xs_pad.size(0), xs_pad.size(1),
                                          xs_pad.size(2) * xs_pad.size(3))
        vgg_out = xs_pad.clone()

        out, enc_lens, cur_state = self.blstm(xs_pad, enc_lens, prev_state)
        mask = to_device(self, make_pad_mask(enc_lens).unsqueeze(-1))
        return out.masked_fill(mask, .0), enc_lens, cur_state, vgg_out
    def forward(self, xs_pad, ilens, ys_pad, decode_step, tf=False, tf_rate=0.0, vis_att=False):
        """
        xs_pad: (B, Tmax, 83)  #83 is 80-dim fbank + 3-dim pitch
        ilens: torch.Tensor with size B
        ys_pad: (B, Lmax), will parse in list of tensors in ctc
        """
        # NOTE: each forward visualizes at most one utterance's attention map
        # (to compare the effect of different lengths), and picks the longest
        # utterance in the batch for simplicity

        assert xs_pad.size(0) == ilens.size(0) == ys_pad.size(0), "Batch size mismatch"
        batch_size = xs_pad.size(0)

        # Put data to device
        xs_pad = to_device(self,xs_pad)
        ilens = to_device(self,ilens)
        ys_pad = to_device(self, ys_pad)

        enc, enc_lens, _ = self.encoder(xs_pad, ilens)
        ctc_output = self.ctc_layer(enc)

        sos = ys_pad.new([self.sos_id])
        eos = ys_pad.new([self.eos_id])

        if tf:
            # ys_pad_in = torch.cat((sos.repeat(batch_size).view(batch_size,-1),ys_pad),1)
            # teacher = self.embed(ys_pad_in) #(B, L+1, dec_dim)
            teacher = self.embed(ys_pad) #(B, L, dec_dim)

        self.decoder.init_rnn(enc)
        self.attention.reset()

        last_char_emb = self.embed(sos.repeat(batch_size))  # (B, dec_dim)

        output_char_seq = list()
        if vis_att:
            output_att_seq = list()

        att_w = None

        for t in range(decode_step):
            att_c, att_w = self.attention(enc, enc_lens, self.decoder.state_list[0], att_w)

            dec_inp = torch.cat([last_char_emb, att_c], dim=-1) # (B, dec_dim + enc_o_dim)
            dec_out = self.decoder(dec_inp) #(B, dec_dim)
            cur_char = self.char_trans(dec_out) #(B, odim)

            # Teacher forcing
            if tf and t < decode_step - 1:
                # if random.random() < tf_rate:
                if torch.rand(1).item() < tf_rate:
                    last_char_emb = teacher[:,t,:]
                else: # scheduled sampling
                    sampled_char = Categorical(F.softmax(cur_char,dim=-1)).sample()
                    last_char_emb = self.embed(sampled_char)
            else: # greedy pick
                last_char_emb = self.embed(torch.argmax(cur_char,dim=-1))

            output_char_seq.append(cur_char)
            if vis_att:
                output_att_seq.append(att_w.detach()[0].view(enc_lens[0],1))

        att_output = torch.stack(output_char_seq, dim=1).view(batch_size*decode_step,-1)
        # att_output = torch.stack(output_char_seq, dim=1) # (B,T,o_dim)
        att_map = torch.stack(output_att_seq, dim=1) if vis_att else None
        # att_map = torch.stack(output_att_seq,dim=1) # (T,L)

        return ctc_output, enc_lens, att_output, att_map
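
# A usage sketch (not part of the original source): the outputs of this forward are
# commonly combined into a weighted CTC + cross-entropy objective. The interpolation
# weight, blank index, and ignore label below are assumptions.
import torch.nn.functional as F

def joint_ctc_att_loss(ctc_output, enc_lens, att_output, ys_pad, ys_lens,
                       ctc_weight=0.5, blank_id=0, ignore_id=-1):
    """Sketch of a joint objective; ctc_weight / blank_id / ignore_id are assumed."""
    # attention branch: att_output is (B*decode_step, odim); flattening ys_pad
    # assumes decode_step == Lmax, as in teacher-forced training
    att_loss = F.cross_entropy(att_output, ys_pad.reshape(-1),
                               ignore_index=ignore_id)
    # CTC branch: (B, T, odim) -> (T, B, odim) log-probs; clamp padding for safety
    log_probs = F.log_softmax(ctc_output, dim=-1).transpose(0, 1)
    ctc = F.ctc_loss(log_probs, ys_pad.clamp(min=0), enc_lens, ys_lens,
                     blank=blank_id, zero_infinity=True)
    return ctc_weight * ctc + (1 - ctc_weight) * att_loss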