def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        '''NoAtt forward

        :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state lengths (B)
        :param Variable dec_z: dummy (not used)
        :param Variable att_prev: dummy (not used)
        :return: attention-weighted encoder state (B, D_enc)
        :rtype: Variable
        :return: previous attention weights
        :rtype: Variable
        '''
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)

        # initialize attention weight with uniform dist.
        if att_prev is None:
            att_prev = [
                Variable(enc_hs_pad.data.new(l).zero_() + (1.0 / l))
                for l in enc_hs_len
            ]
            # if no bias, zero-padded frames stay zero
            att_prev = pad_list(att_prev, 0)
            self.c = torch.sum(self.enc_h *
                               att_prev.view(batch, self.h_length, 1),
                               dim=1)

        return self.c, att_prev
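
# --- Illustrative sketch (not part of the class above; all tensors and sizes
# are made up): with uniform attention weights, the context vector c is simply
# the mean of each utterance's unpadded encoder frames, while zero-weighted
# padding contributes nothing.
import torch

enc_hs_pad = torch.randn(2, 5, 4)                  # B x T_max x D_enc (dummy)
enc_hs_len = [5, 3]                                # true lengths per utterance
att = torch.stack([torch.cat([torch.full((l,), 1.0 / l), torch.zeros(5 - l)])
                   for l in enc_hs_len])           # B x T_max, uniform + 0-pad
c = torch.sum(enc_hs_pad * att.unsqueeze(2), dim=1)         # B x D_enc
assert torch.allclose(c[1], enc_hs_pad[1, :3].mean(dim=0))  # padding ignored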
    def forward(self, xs, ilens):
        '''VGG2L forward

        :param xs: padded input feature sequences (B x T_max x D)
        :param ilens: list of input sequence lengths (B)
        :return: subsampled feature sequences and the correspondingly updated lengths
        '''
        ##logging.info(self.__class__.__name__ + ' input lengths: ' + str(ilens))

        # x: utt x frame x dim
        # xs = F.pad_sequence(xs)

        # x: utt x 1 (input channel num) x frame x dim
        xs = xs.view(xs.size(0), xs.size(1), self.in_channel,
                     xs.size(2) // self.in_channel).transpose(1, 2)

        # NOTE: max_pool1d ?
        xs = F.relu(self.conv1_1(xs))
        xs = F.relu(self.conv1_2(xs))
        xs = F.max_pool2d(xs, 2, stride=2, ceil_mode=True)

        xs = F.relu(self.conv2_1(xs))
        xs = F.relu(self.conv2_2(xs))
        xs = F.max_pool2d(xs, 2, stride=2, ceil_mode=True)
        # change ilens accordingly
        # ilens = [_get_max_pooled_size(i) for i in ilens]
        ilens = np.array(np.ceil(np.array(ilens, dtype=np.float32) / 2),
                         dtype=np.int64)
        ilens = np.array(np.ceil(np.array(ilens, dtype=np.float32) / 2),
                         dtype=np.int64).tolist()

        # x: utt_list of frame (remove zero-padded frames) x (input channel num x dim)
        xs = xs.transpose(1, 2)
        xs = xs.contiguous().view(xs.size(0), xs.size(1),
                                  xs.size(2) * xs.size(3))
        xs = [xs[i, :ilens[i]] for i in range(len(ilens))]
        xs = pad_list(xs, 0.0)
        return xs, ilens
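
# --- Quick check of the length bookkeeping above (standalone sketch, values
# made up): two max-pool layers with stride 2 and ceil_mode=True reduce each
# input length to ceil(ceil(T / 2) / 2), i.e. roughly a factor of four.
import numpy as np

ilens = [100, 83, 7]
sub = np.ceil(np.array(ilens, dtype=np.float32) / 2)
sub = np.ceil(sub / 2).astype(np.int64).tolist()
print(sub)  # [25, 21, 2]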
    def forward(self,
                enc_hs_pad,
                enc_hs_len,
                dec_z,
                att_prev_list,
                scaling=2.0):
        '''AttCovLoc forward

        :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state lengths (B)
        :param Variable dec_z: decoder hidden state (B x D_dec)
        :param list att_prev_list: list of previous attention weights
        :param float scaling: scaling parameter before applying softmax
        :return: attention-weighted encoder state (B, D_enc)
        :rtype: Variable
        :return: list of previous attention weights
        :rtype: list
        '''

        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)

        if dec_z is None:
            dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
        else:
            dec_z = dec_z.view(batch, self.dunits)

        # initialize attention weight with uniform dist.
        if att_prev_list is None:
            att_prev = [
                Variable(enc_hs_pad.data.new(l).zero_() + (1.0 / l))
                for l in enc_hs_len
            ]
            # if no bias, zero-padded frames stay zero
            att_prev_list = [pad_list(att_prev, 0)]

        # att_prev_list: L' * [B x T] => cov_vec B x T
        cov_vec = sum(att_prev_list)

        # cov_vec: B x T -> B x 1 x 1 x T -> B x C x 1 x T
        att_conv = self.loc_conv(cov_vec.view(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = linear_tensor(self.mlp_att, att_conv)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        # NOTE: consider zero padding when computing w.
        e = linear_tensor(
            self.gvec,
            torch.tanh(att_conv + self.pre_compute_enc_h +
                       dec_z_tiled)).squeeze(2)

        w = F.softmax(scaling * e, dim=1)
        att_prev_list += [w]

        # weighted sum over frames
        # utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

        return c, att_prev_list
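
# --- Illustration of the coverage vector used above (toy values): it is the
# element-wise sum of all previous attention distributions, so frames that were
# attended to often accumulate a large coverage value.
import torch

att_prev_list = [torch.tensor([[0.7, 0.2, 0.1]]),
                 torch.tensor([[0.1, 0.8, 0.1]])]   # two past steps, B=1, T=3
cov_vec = sum(att_prev_list)                        # B x T
print(cov_vec)  # tensor([[0.8000, 1.0000, 0.2000]])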
    def forward(self,
                enc_hs_pad,
                enc_hs_len,
                dec_z,
                att_prev_states,
                scaling=2.0):
        '''AttLocRec forward

        :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state lengths (B)
        :param Variable dec_z: decoder hidden state (B x D_dec)
        :param tuple att_prev_states: previous attention weights and LSTM states
                                      ((B, T_max), ((B, att_dim), (B, att_dim)))
        :param float scaling: scaling parameter before applying softmax
        :return: attention-weighted encoder state (B, D_enc)
        :rtype: Variable
        :return: previous attention weights and LSTM states (w, (hx, cx))
                 ((B, T_max), ((B, att_dim), (B, att_dim)))
        :rtype: tuple
        '''

        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)

        if dec_z is None:
            dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
        else:
            dec_z = dec_z.view(batch, self.dunits)

        if att_prev_states is None:
            # initialize attention weight with uniform dist.
            att_prev = [
                Variable(enc_hs_pad.data.new(l).fill_(1.0 / l))
                for l in enc_hs_len
            ]
            # if no bias, zero-padded frames stay zero
            att_prev = pad_list(att_prev, 0)

            # initialize lstm states
            att_h = Variable(enc_hs_pad.data.new(batch, self.att_dim).zero_())
            att_c = Variable(enc_hs_pad.data.new(batch, self.att_dim).zero_())
            att_states = (att_h, att_c)
        else:
            att_prev = att_prev_states[0]
            att_states = att_prev_states[1]

        # B x 1 x 1 x T -> B x C x 1 x T
        att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        # apply non-linear
        att_conv = F.relu(att_conv)
        # B x C x 1 x T -> B x C x 1 x 1 -> B x C
        att_conv = F.max_pool2d(att_conv,
                                (1, att_conv.size(3))).view(batch, -1)

        att_h, att_c = self.att_lstm(att_conv, att_states)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        # NOTE: consider zero padding when computing w.
        e = linear_tensor(
            self.gvec,
            torch.tanh(
                att_h.unsqueeze(1) + self.pre_compute_enc_h +
                dec_z_tiled)).squeeze(2)

        w = F.softmax(scaling * e, dim=1)

        # weighted sum over frames
        # utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

        return c, (w, (att_h, att_c))
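
# --- Minimal sketch of the attention recurrence above (all sizes hypothetical):
# the previous weights are convolved, max-pooled to a fixed-size feature, and
# fed to an LSTMCell whose hidden state conditions the next attention scores.
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, att_dim, chans, T = 2, 8, 4, 10
att_prev = torch.full((batch, T), 1.0 / T)          # uniform previous weights
loc_conv = nn.Conv2d(1, chans, (1, 5), padding=(0, 2), bias=False)
att_lstm = nn.LSTMCell(chans, att_dim)

feat = F.relu(loc_conv(att_prev.view(batch, 1, 1, T)))          # B x C x 1 x T
feat = F.max_pool2d(feat, (1, feat.size(3))).view(batch, -1)    # B x C
att_h, att_c = att_lstm(feat, (torch.zeros(batch, att_dim),
                               torch.zeros(batch, att_dim)))
print(att_h.shape)  # torch.Size([2, 8])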
    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        '''AttLoc2D forward

        :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state lengths (B)
        :param Variable dec_z: decoder hidden state (B x D_dec)
        :param Variable att_prev: previous attention weights (B x att_win x T_max)
        :param float scaling: scaling parameter before applying softmax
        :return: attention-weighted encoder state (B, D_enc)
        :rtype: Variable
        :return: previous attention weights (B x att_win x T_max)
        :rtype: Variable
        '''

        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)

        if dec_z is None:
            dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
        else:
            dec_z = dec_z.view(batch, self.dunits)

        # initialize attention weight with uniform dist.
        if att_prev is None:
            # B * [Li x att_win]
            att_prev = [
                Variable(
                    enc_hs_pad.data.new(l, self.att_win).zero_() + 1.0 / l)
                for l in enc_hs_len
            ]
            # if no bias, zero-padded frames stay zero
            att_prev = pad_list(att_prev, 0).transpose(1, 2)

        # att_prev: B x att_win x Tmax -> B x 1 x att_win x Tmax -> B x C x 1 x Tmax
        att_conv = self.loc_conv(att_prev.unsqueeze(1))
        # att_conv: B x C x 1 x Tmax -> B x Tmax x C
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = linear_tensor(self.mlp_att, att_conv)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        # NOTE: consider zero padding when computing w.
        e = linear_tensor(
            self.gvec,
            torch.tanh(att_conv + self.pre_compute_enc_h +
                       dec_z_tiled)).squeeze(2)

        w = F.softmax(scaling * e, dim=1)

        # weighted sum over frames
        # utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

        # update att_prev: B x att_win x Tmax -> B x att_win+1 x Tmax -> B x att_win x Tmax
        att_prev = torch.cat([att_prev, w.unsqueeze(1)], dim=1)
        att_prev = att_prev[:, 1:]

        return c, att_prev
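
# --- Sketch of the sliding-window update above (toy values): the newest weight
# vector is appended and the oldest dropped, so att_prev always holds the most
# recent att_win attention distributions.
import torch

batch, att_win, T = 1, 3, 4
att_prev = torch.full((batch, att_win, T), 0.25)
w = torch.tensor([[0.1, 0.2, 0.3, 0.4]])             # newest weights, B x T
att_prev = torch.cat([att_prev, w.unsqueeze(1)], dim=1)[:, 1:]
print(att_prev.shape)   # torch.Size([1, 3, 4]); last slice along dim=1 is w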
    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        '''AttLoc forward

        :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state lengths (B)
        :param Variable dec_z: decoder hidden state (B x D_dec)
        :param Variable att_prev: previous attention weights (B x T_max)
        :param float scaling: scaling parameter before applying softmax
        :return: attention-weighted encoder state (B, D_enc)
        :rtype: Variable
        :return: previous attention weights (B x T_max)
        :rtype: Variable
        '''

        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)

        if dec_z is None:
            dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
        else:
            dec_z = dec_z.view(batch, self.dunits)

        # initialize attention weight with uniform dist.
        if att_prev is None:
            att_prev = [
                Variable(enc_hs_pad.data.new(l).zero_() + (1.0 / l))
                for l in enc_hs_len
            ]
            # if no bias, zero-padded frames stay zero
            att_prev = pad_list(att_prev, 0)

        # att_prev: utt x frame -> utt x 1 x 1 x frame -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = linear_tensor(self.mlp_att, att_conv)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        # NOTE: consider zero padding when computing w.
        e = linear_tensor(
            self.gvec,
            torch.tanh(att_conv + self.pre_compute_enc_h +
                       dec_z_tiled)).squeeze(2)

        # added by bliu: optionally apply a sigmoid before or instead of softmax
        if self.aact_fuc == 'softmax':
            w = F.softmax(scaling * e, dim=1)
        elif self.aact_fuc == 'sigmoid':
            w = F.sigmoid(scaling * e)
        elif self.aact_fuc == 'sigmoid_softmax':
            e = torch.sigmoid(e)
            w = F.softmax(scaling * e, dim=1)

        # weighted sum over frames
        # utt x hdim
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

        return c, w
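
# --- Comparison of the three activation choices above on toy scores: softmax
# yields a normalized distribution over frames, sigmoid yields independent
# gates in (0, 1) with no normalization, and sigmoid_softmax squashes the
# scores before normalizing them.
import torch
import torch.nn.functional as F

e, scaling = torch.tensor([[1.0, 2.0, 3.0]]), 2.0
print(F.softmax(scaling * e, dim=1))                 # rows sum to 1
print(torch.sigmoid(scaling * e))                    # each entry in (0, 1)
print(F.softmax(scaling * torch.sigmoid(e), dim=1))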
    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        '''AttMultiHeadMultiResLoc forward

        :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state lengths (B)
        :param Variable dec_z: decoder hidden state (B x D_dec)
        :param Variable att_prev: list of previous attention weights (B x T_max) * aheads
        :return: attention-weighted encoder state (B x D_enc)
        :rtype: Variable
        :return: list of previous attention weights (B x T_max) * aheads
        :rtype: list
        '''

        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop
        if self.pre_compute_k is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_k = [
                linear_tensor(self.mlp_k[h], self.enc_h)
                for h in six.moves.range(self.aheads)
            ]

        if self.pre_compute_v is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_v = [
                linear_tensor(self.mlp_v[h], self.enc_h)
                for h in six.moves.range(self.aheads)
            ]

        if dec_z is None:
            dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
        else:
            dec_z = dec_z.view(batch, self.dunits)

        if att_prev is None:
            att_prev = []
            for h in six.moves.range(self.aheads):
                att_prev += [[
                    Variable(enc_hs_pad.data.new(l).zero_() + (1.0 / l))
                    for l in enc_hs_len
                ]]
                # if no bias, zero-padded frames stay zero
                att_prev[h] = pad_list(att_prev[h], 0)

        c = []
        w = []
        for h in six.moves.range(self.aheads):
            att_conv = self.loc_conv[h](att_prev[h].view(
                batch, 1, 1, self.h_length))
            att_conv = att_conv.squeeze(2).transpose(1, 2)
            att_conv = linear_tensor(self.mlp_att[h], att_conv)

            e = linear_tensor(
                self.gvec[h],
                torch.tanh(self.pre_compute_k[h] + att_conv + self.mlp_q[h]
                           (dec_z).view(batch, 1, self.att_dim_k))).squeeze(2)
            w += [F.softmax(self.scaling * e, dim=1)]

            # weighted sum over frames
            # utt x hdim
            # NOTE use bmm instead of sum(*)
            c += [
                torch.sum(self.pre_compute_v[h] *
                          w[h].view(batch, self.h_length, 1),
                          dim=1)
            ]

        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))

        return c, w
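
# --- Sketch of the head merge above (sizes hypothetical): each head yields its
# own context vector; these are concatenated along the feature dimension and
# projected back to the decoder size by an output layer like mlp_o.
import torch
import torch.nn as nn

aheads, att_dim_v, dunits, batch = 4, 8, 16, 2
mlp_o = nn.Linear(aheads * att_dim_v, dunits, bias=False)
c_per_head = [torch.randn(batch, att_dim_v) for _ in range(aheads)]
c = mlp_o(torch.cat(c_per_head, dim=1))              # B x dunits
print(c.shape)  # torch.Size([2, 16])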
    def forward(self, hpad, hlen, ys, scheduled_sampling_rate):
        '''Decoder forward

        :param hpad: padded encoder hidden state sequences (B x T_max x D_enc)
        :param hlen: list of encoder hidden state sequence lengths (B)
        :param ys: list of target label ID sequences
        :param scheduled_sampling_rate: probability of feeding back the model's own prediction instead of the ground truth
        :return: attention loss and accuracy
        '''
        hpad = mask_by_length(hpad, hlen, 0)
        hlen = list(map(int, hlen))

        self.loss = None
        # prepare input and output word sequences with sos/eos IDs
        eos = Variable(ys[0].data.new([self.eos]))
        sos = Variable(ys[0].data.new([self.sos]))
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]

        # padding for ys with -1
        # pys: utt x olen
        pad_ys_in = pad_list(ys_in, self.eos)
        pad_ys_out = pad_list(ys_out, self.ignore_id)
        # get dim, length info
        batch = pad_ys_out.size(0)
        olength = pad_ys_out.size(1)
        ##logging.info(self.__class__.__name__ + ' input lengths:  ' + str(hlen))
        ##logging.info(self.__class__.__name__ + ' output lengths: ' + str([y.size(0) for y in ys_out]))

        # initialization
        c_list = [self.zero_state(hpad)]
        z_list = [self.zero_state(hpad)]
        for l in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(hpad))
            z_list.append(self.zero_state(hpad))
        att_w = None
        z_all = []
        y_all = []
        self.att.reset()  # reset pre-computation of h

        # pre-computation of embedding
        eys = self.embed(pad_ys_in)  # utt x olen x zdim

        rnnlm_state_prev = None
        # loop for an output sequence
        for i in six.moves.range(olength):
            att_c, att_w = self.att(hpad, hlen, z_list[0], att_w)
            if random.random() < scheduled_sampling_rate and i > 0:
                topv, topi = y_i.topk(1)
                topi = topi.squeeze(1)
                ey_top = self.embed(topi)  # utt x zdim
                ey = torch.cat((ey_top, att_c), dim=1)  # utt x (zdim + hdim)
            else:
                topi = pad_ys_in[:, i]
                ey = torch.cat((eys[:, i, :], att_c),
                               dim=1)  # utt x (zdim + hdim)
            z_list[0], c_list[0] = self.decoder[0](ey, (z_list[0], c_list[0]))
            for l in six.moves.range(1, self.dlayers):
                z_list[l], c_list[l] = self.decoder[l](z_list[l - 1],
                                                       (z_list[l], c_list[l]))

            if self.fusion == 'deep_fusion' and self.rnnlm is not None:
                rnnlm_state, lm_scores = self.rnnlm.predict(
                    rnnlm_state_prev, topi)
                lm_state = rnnlm_state['h2']
                gi = F.sigmoid(self.gate_linear(lm_state))
                output_in = torch.cat((z_list[-1], gi * lm_state), dim=1)
                rnnlm_state_prev = rnnlm_state
            elif self.fusion == 'cold_fusion' and self.rnnlm is not None:
                rnnlm_state, lm_scores = self.rnnlm.predict(
                    rnnlm_state_prev, topi)
                lm_state = F.relu(self.lm_linear(lm_scores))
                gi = F.sigmoid(
                    self.gate_linear(torch.cat((lm_state, z_list[-1]), dim=1)))
                output_in = torch.cat((z_list[-1], gi * lm_state), dim=1)
                rnnlm_state_prev = rnnlm_state
            else:
                output_in = z_list[-1]
            y_i = self.output(output_in)
            y_all.append(y_i)
            z_all.append(z_list[-1])

        y_all = torch.stack(y_all, dim=0).transpose(0, 1).contiguous().view(
            batch * olength, -1)
        self.loss = F.cross_entropy(y_all,
                                    pad_ys_out.view(-1),
                                    ignore_index=self.ignore_id,
                                    size_average=True)
        # -1: eos, which is removed in the loss computation
        self.loss *= (np.mean([len(x) for x in ys_in]) - 1)

        acc = th_accuracy(y_all, pad_ys_out, ignore_label=self.ignore_id)
        if self.labeldist is not None:
            if self.vlabeldist is None:
                self.vlabeldist = to_cuda(
                    self, Variable(torch.from_numpy(self.labeldist)))
            loss_reg = -torch.sum(
                (F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1),
                dim=0) / len(ys_in)
            self.loss = (
                1. - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg

        return self.loss, acc
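
# --- Toy illustration of the label-smoothing regularizer above (dimensions and
# label prior made up): the cross entropy is mixed with a penalty against a
# fixed label distribution, weighted by lsm_weight, while padding positions are
# skipped via ignore_index.
import torch
import torch.nn.functional as F

batch, olength, odim = 2, 3, 5
y_all = torch.randn(batch * olength, odim)            # flattened logits
targets = torch.tensor([1, 2, -1, 3, 0, -1])          # -1 = ignore_id (padding)
lsm_weight = 0.1
vlabeldist = torch.full((odim,), 1.0 / odim)          # assumed uniform prior
ce = F.cross_entropy(y_all, targets, ignore_index=-1)
loss_reg = -torch.sum(F.log_softmax(y_all, dim=1) * vlabeldist) / batch
loss = (1.0 - lsm_weight) * ce + lsm_weight * loss_reg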
    def calculate_all_attentions(self, hpad, hlen, ys):
        '''Calculate all of the attention weights

        :param hpad: padded encoder hidden state sequences (B x T_max x D_enc)
        :param hlen: list of encoder hidden state sequence lengths (B)
        :param ys: list of target label ID sequences
        :return: attention weights as a numpy array
        '''
        hlen = list(map(int, hlen))
        hpad = mask_by_length(hpad, hlen, 0)
        self.loss = None
        # prepare input and output word sequences with sos/eos IDs
        eos = Variable(ys[0].data.new([self.eos]))
        sos = Variable(ys[0].data.new([self.sos]))
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]

        # padding for ys with -1
        # pys: utt x olen
        pad_ys_in = pad_list(ys_in, self.eos)
        pad_ys_out = pad_list(ys_out, self.ignore_id)

        # get length info
        olength = pad_ys_out.size(1)

        # initialization
        c_list = [self.zero_state(hpad)]
        z_list = [self.zero_state(hpad)]
        for l in six.moves.range(1, self.dlayers):
            c_list.append(self.zero_state(hpad))
            z_list.append(self.zero_state(hpad))
        att_w = None
        att_ws = []
        self.att.reset()  # reset pre-computation of h

        # pre-computation of embedding
        eys = self.embed(pad_ys_in)  # utt x olen x zdim
        rnnlm_state_prev = None

        # loop for an output sequence
        for i in six.moves.range(olength):
            att_c, att_w = self.att(hpad, hlen, z_list[0], att_w)
            if i > 0:
                topv, topi = y_i.topk(1)
                topi = topi.squeeze(1)
                ey_top = self.embed(topi)  # utt x zdim
                ey = torch.cat((ey_top, att_c), dim=1)  # utt x (zdim + hdim)
            else:
                topi = pad_ys_in[:, i]
                ey = torch.cat((eys[:, i, :], att_c),
                               dim=1)  # utt x (zdim + hdim)
            z_list[0], c_list[0] = self.decoder[0](ey, (z_list[0], c_list[0]))
            for l in six.moves.range(1, self.dlayers):
                z_list[l], c_list[l] = self.decoder[l](z_list[l - 1],
                                                       (z_list[l], c_list[l]))
            att_ws.append(att_w)

            if self.fusion == 'deep_fusion' and self.rnnlm is not None:
                rnnlm_state, lm_scores = self.rnnlm.predict(
                    rnnlm_state_prev, topi)
                lm_state = rnnlm_state['h2']
                gi = F.sigmoid(self.gate_linear(lm_state))
                output_in = torch.cat((z_list[-1], gi * lm_state), dim=1)
                rnnlm_state_prev = rnnlm_state
            elif self.fusion == 'cold_fusion' and self.rnnlm is not None:
                rnnlm_state, lm_scores = self.rnnlm.predict(
                    rnnlm_state_prev, topi)
                lm_state = F.relu(self.lm_linear(lm_scores))
                gi = F.sigmoid(
                    self.gate_linear(torch.cat((lm_state, z_list[-1]), dim=1)))
                output_in = torch.cat((z_list[-1], gi * lm_state), dim=1)
                rnnlm_state_prev = rnnlm_state
            else:
                output_in = z_list[-1]
            y_i = self.output(output_in)

        # convert to numpy array with the shape (B, Lmax, Tmax)
        if isinstance(self.att, e2e_attention.AttLoc2D):
            # att_ws => list of previously concatenated attentions
            att_ws = torch.stack([aw[:, -1] for aw in att_ws],
                                 dim=1).data.cpu().numpy()
        elif isinstance(self.att,
                        (e2e_attention.AttCov, e2e_attention.AttCovLoc)):
            # att_ws => list of list of previous attentions
            att_ws = torch.stack([aw[-1] for aw in att_ws],
                                 dim=1).data.cpu().numpy()
        elif isinstance(self.att, e2e_attention.AttLocRec):
            # att_ws => list of tuple of attention and hidden states
            att_ws = torch.stack([aw[0] for aw in att_ws],
                                 dim=1).data.cpu().numpy()
        elif isinstance(
                self.att,
            (e2e_attention.AttMultiHeadDot, e2e_attention.AttMultiHeadAdd,
             e2e_attention.AttMultiHeadLoc,
             e2e_attention.AttMultiHeadMultiResLoc)):
            # att_ws => list of list of each head's attention
            n_heads = len(att_ws[0])
            att_ws_sorted_by_head = []
            for h in six.moves.range(n_heads):
                att_ws_head = torch.stack([aw[h] for aw in att_ws], dim=1)
                att_ws_sorted_by_head += [att_ws_head]
            att_ws = torch.stack(att_ws_sorted_by_head,
                                 dim=1).data.cpu().numpy()
        else:
            # att_ws => list of attentions
            att_ws = torch.stack(att_ws, dim=1).data.cpu().numpy()
        return att_ws
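
# --- Shape illustration for the default branch above (random weights): stacking
# L decoder steps of (B x T) attention weights along dim=1 gives an array of
# shape (B, Lmax, Tmax), ready for plotting one alignment map per utterance.
import torch

B, L, T = 2, 4, 7
att_ws = [torch.softmax(torch.randn(B, T), dim=1) for _ in range(L)]
att_ws = torch.stack(att_ws, dim=1).numpy()
print(att_ws.shape)  # (2, 4, 7)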