def calculate_all_attentions(self, hs_pad, hlen, ys_pad, strm_idx=0):
    """Calculate all of the attention weights.

    :param torch.Tensor hs_pad: batch of padded hidden state sequences
                                (B, Tmax, D)
    :param torch.Tensor hlen: batch of lengths of hidden state sequences (B)
    :param torch.Tensor ys_pad: batch of padded character id sequence tensor
                                (B, Lmax)
    :param int strm_idx: stream index for parallel speaker attention
                         in multi-speaker case
    :return: attention weights with the following shape,
        1) multi-head case => attention weights (B, H, Lmax, Tmax),
        2) other case => attention weights (B, Lmax, Tmax).
    :rtype: float ndarray
    """
    # TODO(kan-bayashi): need to make more smart way
    ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
    att_idx = min(strm_idx, len(self.att) - 1)

    # hlen should be a list of integers
    hlen = list(map(int, hlen))

    self.loss = None
    # prepare input and output word sequences with sos/eos IDs
    eos = ys[0].new([self.eos])
    sos = ys[0].new([self.sos])
    ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]

    # padding for ys with -1
    # pys: utt x olen
    ys_in_pad = pad_list(ys_in, self.eos)
    ys_out_pad = pad_list(ys_out, self.ignore_id)

    # get length info
    olength = ys_out_pad.size(1)

    # initialization
    c_list = [self.zero_state(hs_pad)]
    z_list = [self.zero_state(hs_pad)]
    for _ in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(hs_pad))
        z_list.append(self.zero_state(hs_pad))
    att_w = None
    att_ws = []
    self.att[att_idx].reset()  # reset pre-computation of h

    # pre-computation of embedding
    eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

    # loop for an output sequence
    for i in six.moves.range(olength):
        att_c, att_w = self.att[att_idx](
            hs_pad, hlen, self.dropout_dec[0](z_list[0]), att_w
        )
        ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
        att_ws.append(att_w)

    # convert to numpy array with the shape (B, Lmax, Tmax)
    att_ws = att_to_numpy(att_ws, self.att[att_idx])
    return att_ws
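
# --- Usage sketch (not part of the original source) ---------------------------
# The helper below is hypothetical; it assumes a decoder instance exposing the
# method above, plus encoder outputs `hs_pad`, their lengths `hlens`, and padded
# targets `ys_pad` prepared elsewhere. It only illustrates how the returned
# ndarray could be visualized with matplotlib.
import matplotlib.pyplot as plt


def plot_first_attention(decoder, hs_pad, hlens, ys_pad, out_png="att.png"):
    """Plot the attention map of the first utterance in the batch."""
    att_ws = decoder.calculate_all_attentions(hs_pad, hlens, ys_pad)
    aw = att_ws[0]  # (Lmax, Tmax) or, for multi-head attention, (H, Lmax, Tmax)
    if aw.ndim == 3:
        aw = aw[0]  # show the first head only
    plt.imshow(aw, aspect="auto", origin="lower")
    plt.xlabel("encoder frame")
    plt.ylabel("output token")
    plt.savefig(out_png)
    plt.close()
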
def calculate_all_attentions(self, hs_pad, hlen, ys_pad):
    """Calculate all of the attention weights.

    Args:
        hs_pad (torch.Tensor): batch of padded hidden state sequences
            (B, Tmax, D)
        hlen (torch.Tensor): batch of lengths of hidden state sequences (B)
        ys_pad (torch.Tensor): batch of padded character id sequence tensor
            (B, Lmax)

    Returns:
        att_ws (ndarray): attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).

    """
    # remove padding and prepend the blank symbol to each target sequence
    ys = [y[y != self.ignore_id] for y in ys_pad]
    hlen = list(map(int, hlen))

    blank = ys[0].new([self.blank])
    ys_in = [torch.cat([blank, y], dim=0) for y in ys]
    ys_in_pad = pad_list(ys_in, self.blank)

    olength = ys_in_pad.size(1)

    # initialize decoder hidden/cell states, one pair per decoder layer
    c_list = [self.zero_state(hs_pad)]
    z_list = [self.zero_state(hs_pad)]
    for _ in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(hs_pad))
        z_list.append(self.zero_state(hs_pad))
    att_w = None
    att_ws = []
    self.att[0].reset()  # reset pre-computation of h

    eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

    # loop over the output sequence, collecting the attention weights per step
    for i in six.moves.range(olength):
        att_c, att_w = self.att[0](
            hs_pad, hlen, self.dropout_dec[0](z_list[0]), att_w
        )
        ey = torch.cat((eys[:, i, :], att_c), dim=1)
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
        att_ws.append(att_w)

    # convert to numpy array with the shape (B, Lmax, Tmax)
    att_ws = att_to_numpy(att_ws, self.att[0])

    return att_ws
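
# --- Padding behaviour assumed above (illustrative only) ----------------------
# The snippets rely on `pad_list` to stack variable-length 1-D label tensors
# into a (B, Lmax) batch, filling the tail of shorter sequences with a pad
# value (eos, blank, or ignore_id above). A minimal self-contained sketch of
# that behaviour, not the library's implementation:
import torch


def pad_list_sketch(xs, pad_value):
    """Pad a list of 1-D tensors to the longest length with `pad_value`."""
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    padded = xs[0].new_full((n_batch, max_len), pad_value)
    for i, x in enumerate(xs):
        padded[i, : x.size(0)] = x
    return padded


# pad_list_sketch([torch.tensor([1, 2, 3]), torch.tensor([4])], 0)
# -> tensor([[1, 2, 3], [4, 0, 0]])
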
def calculate_all_attentions(self, hs_pad, hlens, ys_pad):
    """Calculate all of the attention weights.

    Args:
        hs_pad (torch.Tensor): batch of padded hidden state sequences
            (B, Tmax, D)
        hlens (torch.Tensor): batch of lengths of hidden state sequences (B)
        ys_pad (torch.Tensor): batch of padded character id sequence tensor
            (B, Lmax)

    Returns:
        att_ws (ndarray): attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).

    """
    # remove padding and prepend the blank symbol to each target sequence
    ys = [y[y != self.ignore_id] for y in ys_pad]
    hlens = list(map(int, hlens))

    blank = ys[0].new([self.blank])
    ys_in = [torch.cat([blank, y], dim=0) for y in ys]
    ys_in_pad = pad_list(ys_in, self.blank)

    olength = ys_in_pad.size(1)

    att_ws = []
    self.att[0].reset()  # reset pre-computation of h

    eys = self.embed(ys_in_pad)  # utt x olen x zdim
    # initialize the decoder state and attention weights
    state, att_w = self.init_state(eys)

    # loop over the output sequence, collecting the attention weights per step
    for i in range(olength):
        att_c, att_w = self.att[0](
            hs_pad, hlens, self.dropout_dec[0](state[0][0]), att_w
        )
        ey = torch.cat((eys[:, i, :], att_c), dim=1)
        _, state = self.rnn_forward(ey, state)
        att_ws.append(att_w)

    # convert to numpy array with the shape (B, Lmax, Tmax)
    att_ws = att_to_numpy(att_ws, self.att[0])

    return att_ws
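
# --- Conversion step assumed above (illustrative only) ------------------------
# Each variant finishes with `att_to_numpy`, which turns the per-step weights
# collected in the decoding loop into one numpy array. A simplified sketch for
# the single-head case only; the real helper also covers multi-head attention,
# which yields the (B, H, Lmax, Tmax) shape mentioned in the docstrings:
import torch


def att_to_numpy_sketch(att_ws):
    """Stack a list of Lmax per-step (B, Tmax) weights into (B, Lmax, Tmax)."""
    return torch.stack(att_ws, dim=1).detach().cpu().numpy()
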
def calculate_all_attentions(self, hs_pad, hlen, ys_pad, strm_idx=0, lang_ids=None):
    """Calculate all of the attention weights.

    :param torch.Tensor hs_pad: batch of padded hidden state sequences
                                (B, Tmax, D)
                                [in multi-encoder case, list of torch.Tensor,
                                [(B, Tmax_1, D), (B, Tmax_2, D), ...]]
    :param torch.Tensor hlen: batch of lengths of hidden state sequences (B)
                              [in multi-encoder case, list of torch.Tensor,
                              [(B), (B), ...]]
    :param torch.Tensor ys_pad: batch of padded character id sequence tensor
                                (B, Lmax)
    :param int strm_idx: stream index for parallel speaker attention
                         in multi-speaker case
    :param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
    :return: attention weights with the following shape,
        1) multi-head case => attention weights (B, H, Lmax, Tmax),
        2) multi-encoder case =>
           [(B, Lmax, Tmax1), (B, Lmax, Tmax2), ..., (B, Lmax, NumEncs)],
        3) other case => attention weights (B, Lmax, Tmax).
    :rtype: float ndarray
    """
    # to support multiple-encoder ASR mode; in single-encoder mode,
    # convert torch.Tensor to a list of torch.Tensor
    if self.num_encs == 1:
        hs_pad = [hs_pad]
        hlen = [hlen]

    # TODO(kan-bayashi): need to make more smart way
    ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
    att_idx = min(strm_idx, len(self.att) - 1)

    # hlen should be a list of lists of integers
    hlen = [list(map(int, hlen[idx])) for idx in range(self.num_encs)]

    self.loss = None
    # prepare input and output word sequences with sos/eos IDs
    eos = ys[0].new([self.eos])
    sos = ys[0].new([self.sos])
    if self.replace_sos:
        ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
    else:
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]

    # padding for ys with -1
    # pys: utt x olen
    ys_in_pad = pad_list(ys_in, self.eos)
    ys_out_pad = pad_list(ys_out, self.ignore_id)

    # get length info
    olength = ys_out_pad.size(1)

    # initialization
    c_list = [self.zero_state(hs_pad[0])]
    z_list = [self.zero_state(hs_pad[0])]
    for _ in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(hs_pad[0]))
        z_list.append(self.zero_state(hs_pad[0]))
    att_ws = []
    if self.num_encs == 1:
        att_w = None
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        att_w_list = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * self.num_encs  # atts
        for idx in range(self.num_encs + 1):
            # reset pre-computation of h in atts and han
            self.att[idx].reset()

    # pre-computation of embedding
    eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

    # loop for an output sequence
    for i in six.moves.range(olength):
        if self.num_encs == 1:
            att_c, att_w = self.att[att_idx](
                hs_pad[0], hlen[0], self.dropout_dec[0](z_list[0]), att_w
            )
            att_ws.append(att_w)
        else:
            for idx in range(self.num_encs):
                att_c_list[idx], att_w_list[idx] = self.att[idx](
                    hs_pad[idx],
                    hlen[idx],
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[idx],
                )
            hs_pad_han = torch.stack(att_c_list, dim=1)
            hlen_han = [self.num_encs] * len(ys_in)
            att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                hs_pad_han,
                hlen_han,
                self.dropout_dec[0](z_list[0]),
                att_w_list[self.num_encs],
            )
            # append a shallow copy: the entries of att_w_list are reassigned
            # at the next step, so appending the list itself would alias the
            # final step's weights for every step
            att_ws.append(att_w_list[:])
        ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)

    if self.num_encs == 1:
        # convert to numpy array with the shape (B, Lmax, Tmax)
        att_ws = att_to_numpy(att_ws, self.att[att_idx])
    else:
        _att_ws = []
        for idx, ws in enumerate(zip(*att_ws)):
            ws = att_to_numpy(ws, self.att[idx])
            _att_ws.append(ws)
        att_ws = _att_ws
    return att_ws
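
# --- Reshaping in the multi-encoder branch (illustrative only) ----------------
# In the multi-encoder case above, `att_ws` is collected step-major (one list of
# num_encs + 1 weights per output step) and `zip(*att_ws)` transposes it into
# encoder-major tuples before each is converted with `att_to_numpy`. A tiny
# illustration with placeholder strings standing in for the weight tensors:
steps = [
    ["enc1_w_step0", "enc2_w_step0", "han_w_step0"],
    ["enc1_w_step1", "enc2_w_step1", "han_w_step1"],
]
per_encoder = list(zip(*steps))
# per_encoder[0] -> ("enc1_w_step0", "enc1_w_step1")  # encoder 1 over all steps
# per_encoder[2] -> ("han_w_step0", "han_w_step1")    # hierarchical attention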