def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
    '''NoAtt forward

    :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
    :param list enc_hs_len: padded encoder hidden state lengths (B)
    :param Variable dec_z: dummy (not used)
    :param Variable att_prev: dummy (not used)
    :return: attention weighted encoder state (B, D_enc)
    :rtype: Variable
    :return: previous attention weights
    :rtype: Variable
    '''
    batch = len(enc_hs_pad)
    # pre-compute all h outside the decoder loop
    if self.pre_compute_enc_h is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)

    # initialize attention weight with uniform dist.
    if att_prev is None:
        att_prev = [Variable(enc_hs_pad.data.new(l).zero_() + (1.0 / l))
                    for l in enc_hs_len]
        # if no bias, 0-padded frames stay 0
        att_prev = pad_list(att_prev, 0)
        self.c = torch.sum(
            self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1)

    return self.c, att_prev
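# Hedged, standalone sketch (not part of the class above) of what NoAtt computes:
# a fixed uniform attention over each utterance's valid frames, followed by a
# weighted sum over time.  `_pad_list_demo` is a hypothetical stand-in for the
# repo's `pad_list` helper; shapes and values are illustrative only.
import torch


def _pad_list_demo(xs, pad_value):
    # pad a list of variable-length tensors along dim 0 to a common length
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    padded = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
    for i, x in enumerate(xs):
        padded[i, :x.size(0)] = x
    return padded


def _example_no_att():
    enc_hs_len = [5, 3]                 # valid frame counts per utterance
    enc_hs_pad = torch.randn(2, 5, 4)   # B x T_max x D_enc (padding assumed zero)
    att_prev = _pad_list_demo(
        [torch.full((l,), 1.0 / l) for l in enc_hs_len], 0)
    # weighted sum over frames -> B x D_enc context vector
    c = torch.sum(enc_hs_pad * att_prev.unsqueeze(2), dim=1)
    return c, att_prev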
def forward(self, xs, ilens):
    '''VGG2L forward

    :param Variable xs: padded input features (B x T_max x D)
    :param list ilens: list of input lengths (B)
    :return: padded VGG2L output and the subsampled lengths
    :rtype: tuple
    '''
    # logging.info(self.__class__.__name__ + ' input lengths: ' + str(ilens))

    # x: utt x frame x dim
    # xs = F.pad_sequence(xs)

    # x: utt x 1 (input channel num) x frame x dim
    xs = xs.view(xs.size(0), xs.size(1), self.in_channel,
                 xs.size(2) // self.in_channel).transpose(1, 2)

    # NOTE: max_pool1d ?
    xs = F.relu(self.conv1_1(xs))
    xs = F.relu(self.conv1_2(xs))
    xs = F.max_pool2d(xs, 2, stride=2, ceil_mode=True)

    xs = F.relu(self.conv2_1(xs))
    xs = F.relu(self.conv2_2(xs))
    xs = F.max_pool2d(xs, 2, stride=2, ceil_mode=True)

    # change ilens accordingly
    # ilens = [_get_max_pooled_size(i) for i in ilens]
    ilens = np.array(
        np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64)
    ilens = np.array(
        np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64).tolist()

    # x: utt_list of frame (remove zero-padded frames) x (input channel num x dim)
    xs = xs.transpose(1, 2)
    xs = xs.contiguous().view(xs.size(0), xs.size(1), xs.size(2) * xs.size(3))
    xs = [xs[i, :ilens[i]] for i in range(len(ilens))]
    xs = pad_list(xs, 0.0)
    return xs, ilens
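# Hedged sketch of the length bookkeeping above: the two ceil-mode max-poolings each
# halve the frame axis, so ilens must be ceil-divided by 2 twice.  Pure numpy, no
# dependence on the VGG2L module itself; the default lengths are illustrative.
import numpy as np


def _example_vgg2l_ilens(ilens=(100, 63, 17)):
    ilens = np.array(np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64)
    ilens = np.array(np.ceil(ilens.astype(np.float32) / 2), dtype=np.int64).tolist()
    return ilens  # [25, 16, 5] for the default lengths above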
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0):
    '''AttCovLoc forward

    :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
    :param list enc_hs_len: padded encoder hidden state lengths (B)
    :param Variable dec_z: decoder hidden state (B x D_dec)
    :param list att_prev_list: list of previous attention weights
    :param float scaling: scaling parameter before applying softmax
    :return: attention weighted encoder state (B, D_enc)
    :rtype: Variable
    :return: list of previous attention weights
    :rtype: list
    '''
    batch = len(enc_hs_pad)
    # pre-compute all h outside the decoder loop
    if self.pre_compute_enc_h is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)
        # utt x frame x att_dim
        self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)

    if dec_z is None:
        dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
    else:
        dec_z = dec_z.view(batch, self.dunits)

    # initialize attention weight with uniform dist.
    if att_prev_list is None:
        att_prev = [Variable(enc_hs_pad.data.new(l).zero_() + (1.0 / l))
                    for l in enc_hs_len]
        # if no bias, 0-padded frames stay 0
        att_prev_list = [pad_list(att_prev, 0)]

    # att_prev_list: L' * [B x T] => cov_vec B x T
    cov_vec = sum(att_prev_list)

    # cov_vec: B x T -> B x 1 x 1 x T -> B x C x 1 x T
    att_conv = self.loc_conv(cov_vec.view(batch, 1, 1, self.h_length))
    # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
    att_conv = att_conv.squeeze(2).transpose(1, 2)
    # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
    att_conv = linear_tensor(self.mlp_att, att_conv)

    # dec_z_tiled: utt x frame x att_dim
    dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

    # dot with gvec
    # utt x frame x att_dim -> utt x frame
    # NOTE: consider zero padding when computing w.
    e = linear_tensor(self.gvec, torch.tanh(
        att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
    w = F.softmax(scaling * e, dim=1)
    att_prev_list += [w]

    # weighted sum over frames
    # utt x hdim
    # NOTE: use bmm instead of sum(*)
    c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

    return c, att_prev_list
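# Hedged sketch: the coverage vector used above is simply the element-wise sum of all
# previously emitted attention distributions (one B x T_max tensor per decoder step),
# i.e. the cumulative attention mass placed on each encoder frame so far.
import torch


def _example_coverage_vector(n_steps=3, B=2, T=7):
    att_prev_list = [torch.softmax(torch.randn(B, T), dim=1) for _ in range(n_steps)]
    cov_vec = sum(att_prev_list)  # B x T cumulative attention mass
    return cov_vec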
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_states, scaling=2.0):
    '''AttLocRec forward

    :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
    :param list enc_hs_len: padded encoder hidden state lengths (B)
    :param Variable dec_z: decoder hidden state (B x D_dec)
    :param tuple att_prev_states: previous attention weights and lstm states
        ((B, T_max), ((B, att_dim), (B, att_dim)))
    :param float scaling: scaling parameter before applying softmax
    :return: attention weighted encoder state (B, D_enc)
    :rtype: Variable
    :return: previous attention weights and lstm states (w, (hx, cx))
        ((B, T_max), ((B, att_dim), (B, att_dim)))
    :rtype: tuple
    '''
    batch = len(enc_hs_pad)
    # pre-compute all h outside the decoder loop
    if self.pre_compute_enc_h is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)
        # utt x frame x att_dim
        self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)

    if dec_z is None:
        dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
    else:
        dec_z = dec_z.view(batch, self.dunits)

    if att_prev_states is None:
        # initialize attention weight with uniform dist.
        att_prev = [Variable(enc_hs_pad.data.new(l).fill_(1.0 / l))
                    for l in enc_hs_len]
        # if no bias, 0-padded frames stay 0
        att_prev = pad_list(att_prev, 0)

        # initialize lstm states
        att_h = Variable(enc_hs_pad.data.new(batch, self.att_dim).zero_())
        att_c = Variable(enc_hs_pad.data.new(batch, self.att_dim).zero_())
        att_states = (att_h, att_c)
    else:
        att_prev = att_prev_states[0]
        att_states = att_prev_states[1]

    # B x 1 x 1 x T -> B x C x 1 x T
    att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
    # apply non-linear
    att_conv = F.relu(att_conv)
    # B x C x 1 x T -> B x C x 1 x 1 -> B x C
    att_conv = F.max_pool2d(att_conv, (1, att_conv.size(3))).view(batch, -1)

    att_h, att_c = self.att_lstm(att_conv, att_states)

    # dec_z_tiled: utt x frame x att_dim
    dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

    # dot with gvec
    # utt x frame x att_dim -> utt x frame
    # NOTE: consider zero padding when computing w.
    e = linear_tensor(self.gvec, torch.tanh(
        att_h.unsqueeze(1) + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
    w = F.softmax(scaling * e, dim=1)

    # weighted sum over frames
    # utt x hdim
    # NOTE: use bmm instead of sum(*)
    c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

    return c, (w, (att_h, att_c))
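# Hedged, class-free sketch of the recurrent attention state above: the previous
# weights are convolved, ReLU'd, max-pooled over time, and fed to an LSTMCell whose
# hidden state replaces the usual location features.  Layer sizes below are
# illustrative stand-ins, not the module's actual hyperparameters.
import torch
import torch.nn as nn
import torch.nn.functional as F


def _example_att_lstm_step(B=2, T=7, C=8, att_dim=5):
    loc_conv = nn.Conv2d(1, C, (1, 3), padding=(0, 1), bias=False)
    att_lstm = nn.LSTMCell(C, att_dim)
    att_prev = torch.softmax(torch.randn(B, T), dim=1)
    att_states = (torch.zeros(B, att_dim), torch.zeros(B, att_dim))

    att_conv = F.relu(loc_conv(att_prev.view(B, 1, 1, T)))                # B x C x 1 x T
    att_conv = F.max_pool2d(att_conv, (1, att_conv.size(3))).view(B, -1)  # B x C
    att_h, att_c = att_lstm(att_conv, att_states)                         # B x att_dim each
    return att_h, att_c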
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
    '''AttLoc2D forward

    :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
    :param list enc_hs_len: padded encoder hidden state lengths (B)
    :param Variable dec_z: decoder hidden state (B x D_dec)
    :param Variable att_prev: previous attention weights (B x att_win x T_max)
    :param float scaling: scaling parameter before applying softmax
    :return: attention weighted encoder state (B, D_enc)
    :rtype: Variable
    :return: previous attention weights (B x att_win x T_max)
    :rtype: Variable
    '''
    batch = len(enc_hs_pad)
    # pre-compute all h outside the decoder loop
    if self.pre_compute_enc_h is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)
        # utt x frame x att_dim
        self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)

    if dec_z is None:
        dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
    else:
        dec_z = dec_z.view(batch, self.dunits)

    # initialize attention weight with uniform dist.
    if att_prev is None:
        # B * [Li x att_win]
        att_prev = [Variable(
            enc_hs_pad.data.new(l, self.att_win).zero_() + 1.0 / l)
            for l in enc_hs_len]
        # if no bias, 0-padded frames stay 0
        att_prev = pad_list(att_prev, 0).transpose(1, 2)

    # att_prev: B x att_win x Tmax -> B x 1 x att_win x Tmax -> B x C x 1 x Tmax
    att_conv = self.loc_conv(att_prev.unsqueeze(1))
    # att_conv: B x C x 1 x Tmax -> B x Tmax x C
    att_conv = att_conv.squeeze(2).transpose(1, 2)
    # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
    att_conv = linear_tensor(self.mlp_att, att_conv)

    # dec_z_tiled: utt x frame x att_dim
    dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

    # dot with gvec
    # utt x frame x att_dim -> utt x frame
    # NOTE: consider zero padding when computing w.
    e = linear_tensor(self.gvec, torch.tanh(
        att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
    w = F.softmax(scaling * e, dim=1)

    # weighted sum over frames
    # utt x hdim
    # NOTE: use bmm instead of sum(*)
    c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

    # update att_prev: B x att_win x Tmax -> B x att_win+1 x Tmax -> B x att_win x Tmax
    att_prev = torch.cat([att_prev, w.unsqueeze(1)], dim=1)
    att_prev = att_prev[:, 1:]

    return c, att_prev
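# Hedged sketch of the sliding-window update at the end of AttLoc2D.forward: the newest
# weights are appended and the oldest row dropped, so att_prev always keeps att_win rows.
import torch


def _example_att_win_update(att_prev, w):
    # att_prev: B x att_win x T_max, w: B x T_max
    att_prev = torch.cat([att_prev, w.unsqueeze(1)], dim=1)  # B x (att_win + 1) x T_max
    return att_prev[:, 1:]                                   # back to B x att_win x T_max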
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
    '''AttLoc forward

    :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
    :param list enc_hs_len: padded encoder hidden state lengths (B)
    :param Variable dec_z: decoder hidden state (B x D_dec)
    :param Variable att_prev: previous attention weights (B x T_max)
    :param float scaling: scaling parameter before applying softmax
    :return: attention weighted encoder state (B, D_enc)
    :rtype: Variable
    :return: previous attention weights (B x T_max)
    :rtype: Variable
    '''
    batch = len(enc_hs_pad)
    # pre-compute all h outside the decoder loop
    if self.pre_compute_enc_h is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)
        # utt x frame x att_dim
        self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)

    if dec_z is None:
        dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
    else:
        dec_z = dec_z.view(batch, self.dunits)

    # initialize attention weight with uniform dist.
    if att_prev is None:
        att_prev = [Variable(enc_hs_pad.data.new(l).zero_() + (1.0 / l))
                    for l in enc_hs_len]
        # if no bias, 0-padded frames stay 0
        att_prev = pad_list(att_prev, 0)

    # att_prev: utt x frame -> utt x 1 x 1 x frame -> utt x att_conv_chans x 1 x frame
    att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
    # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
    att_conv = att_conv.squeeze(2).transpose(1, 2)
    # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
    att_conv = linear_tensor(self.mlp_att, att_conv)

    # dec_z_tiled: utt x frame x att_dim
    dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

    # dot with gvec
    # utt x frame x att_dim -> utt x frame
    # NOTE: consider zero padding when computing w.
    e = linear_tensor(self.gvec, torch.tanh(
        att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)

    # added by bliu: optionally apply a sigmoid before/instead of the softmax
    if self.aact_fuc == 'softmax':
        w = F.softmax(scaling * e, dim=1)
    elif self.aact_fuc == 'sigmoid':
        w = F.sigmoid(scaling * e)
    elif self.aact_fuc == 'sigmoid_softmax':
        e = torch.sigmoid(e)
        w = F.softmax(scaling * e, dim=1)

    # weighted sum over frames
    # utt x hdim
    c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

    return c, w
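# Hedged, class-free sketch of the additive energy computed above:
# e = g^T tanh(W_enc h + W_dec z + W_att f), followed by a scaled softmax and a
# weighted sum.  The weight matrices are random stand-ins for self.mlp_enc,
# self.mlp_dec, self.mlp_att and self.gvec; shapes follow the comments in AttLoc.forward.
import torch
import torch.nn.functional as F


def _example_attloc_energy(B=2, T=7, D_enc=4, D_dec=3, att_dim=5, C=8, scaling=2.0):
    enc_h = torch.randn(B, T, D_enc)   # encoder states
    dec_z = torch.randn(B, D_dec)      # decoder state
    att_conv = torch.randn(B, T, C)    # output of the location convolution
    W_enc = torch.randn(D_enc, att_dim)
    W_dec = torch.randn(D_dec, att_dim)
    W_att = torch.randn(C, att_dim)
    gvec = torch.randn(att_dim)

    e = torch.tanh(enc_h @ W_enc + (dec_z @ W_dec).unsqueeze(1) + att_conv @ W_att) @ gvec
    w = F.softmax(scaling * e, dim=1)              # B x T attention weights
    c = torch.sum(enc_h * w.unsqueeze(2), dim=1)   # B x D_enc context vector
    return c, w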
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
    '''AttMultiHeadMultiResLoc forward

    :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
    :param list enc_hs_len: padded encoder hidden state lengths (B)
    :param Variable dec_z: decoder hidden state (B x D_dec)
    :param Variable att_prev: list of previous attention weights (B x T_max) * aheads
    :return: attention weighted encoder state (B x D_enc)
    :rtype: Variable
    :return: list of previous attention weights (B x T_max) * aheads
    :rtype: list
    '''
    batch = enc_hs_pad.size(0)
    # pre-compute all k and v outside the decoder loop
    if self.pre_compute_k is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)
        # utt x frame x att_dim
        self.pre_compute_k = [
            linear_tensor(self.mlp_k[h], self.enc_h)
            for h in six.moves.range(self.aheads)]

    if self.pre_compute_v is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)
        # utt x frame x att_dim
        self.pre_compute_v = [
            linear_tensor(self.mlp_v[h], self.enc_h)
            for h in six.moves.range(self.aheads)]

    if dec_z is None:
        dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
    else:
        dec_z = dec_z.view(batch, self.dunits)

    if att_prev is None:
        att_prev = []
        for h in six.moves.range(self.aheads):
            att_prev += [[Variable(enc_hs_pad.data.new(l).zero_() + (1.0 / l))
                          for l in enc_hs_len]]
            # if no bias, 0-padded frames stay 0
            att_prev[h] = pad_list(att_prev[h], 0)

    c = []
    w = []
    for h in six.moves.range(self.aheads):
        att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length))
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        att_conv = linear_tensor(self.mlp_att[h], att_conv)

        e = linear_tensor(
            self.gvec[h],
            torch.tanh(self.pre_compute_k[h] + att_conv + self.mlp_q[h](dec_z).view(
                batch, 1, self.att_dim_k))).squeeze(2)
        w += [F.softmax(self.scaling * e, dim=1)]

        # weighted sum over frames
        # utt x hdim
        # NOTE: use bmm instead of sum(*)
        c += [torch.sum(
            self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]

    # concat all of c
    c = self.mlp_o(torch.cat(c, dim=1))

    return c, w
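# Hedged sketch of the multi-head combination at the end of the loop above: each head
# yields its own context vector; they are concatenated and mixed by a final linear
# layer (the role of self.mlp_o).  Layer sizes here are illustrative stand-ins.
import torch
import torch.nn as nn


def _example_multihead_combine(aheads=4, B=2, att_dim_v=5, eprojs=16):
    per_head_contexts = [torch.randn(B, att_dim_v) for _ in range(aheads)]
    mlp_o = nn.Linear(aheads * att_dim_v, eprojs, bias=False)
    return mlp_o(torch.cat(per_head_contexts, dim=1))  # B x eprojs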
def forward(self, hpad, hlen, ys, scheduled_sampling_rate):
    '''Decoder forward

    :param Variable hpad: padded encoder hidden states (B x T_max x D_enc)
    :param list hlen: encoder hidden state lengths (B)
    :param list ys: list of target label id sequences (B)
    :param float scheduled_sampling_rate: probability of feeding back the model's
        own prediction instead of the ground-truth label
    :return: loss and accuracy
    '''
    hpad = mask_by_length(hpad, hlen, 0)
    hlen = list(map(int, hlen))
    self.loss = None
    # prepare input and output word sequences with sos/eos IDs
    eos = Variable(ys[0].data.new([self.eos]))
    sos = Variable(ys[0].data.new([self.sos]))
    ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]

    # padding: inputs with eos, outputs with ignore_id (-1)
    # pys: utt x olen
    pad_ys_in = pad_list(ys_in, self.eos)
    pad_ys_out = pad_list(ys_out, self.ignore_id)

    # get dim, length info
    batch = pad_ys_out.size(0)
    olength = pad_ys_out.size(1)
    # logging.info(self.__class__.__name__ + ' input lengths: ' + str(hlen))
    # logging.info(self.__class__.__name__ + ' output lengths: ' + str([y.size(0) for y in ys_out]))

    # initialization
    c_list = [self.zero_state(hpad)]
    z_list = [self.zero_state(hpad)]
    for l in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(hpad))
        z_list.append(self.zero_state(hpad))
    att_w = None
    z_all = []
    y_all = []
    self.att.reset()  # reset pre-computation of h

    # pre-computation of embedding
    eys = self.embed(pad_ys_in)  # utt x olen x zdim
    rnnlm_state_prev = None

    # loop for an output sequence
    for i in six.moves.range(olength):
        att_c, att_w = self.att(hpad, hlen, z_list[0], att_w)
        if random.random() < scheduled_sampling_rate and i > 0:
            topv, topi = y_i.topk(1)
            topi = topi.squeeze(1)
            ey_top = self.embed(topi)  # utt x zdim
            ey = torch.cat((ey_top, att_c), dim=1)  # utt x (zdim + hdim)
        else:
            topi = pad_ys_in[:, i]
            ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
        z_list[0], c_list[0] = self.decoder[0](ey, (z_list[0], c_list[0]))
        for l in six.moves.range(1, self.dlayers):
            z_list[l], c_list[l] = self.decoder[l](
                z_list[l - 1], (z_list[l], c_list[l]))

        if self.fusion == 'deep_fusion' and self.rnnlm is not None:
            rnnlm_state, lm_scores = self.rnnlm.predict(rnnlm_state_prev, topi)
            lm_state = rnnlm_state['h2']
            gi = F.sigmoid(self.gate_linear(lm_state))
            output_in = torch.cat((z_list[-1], gi * lm_state), dim=1)
            rnnlm_state_prev = rnnlm_state
        elif self.fusion == 'cold_fusion' and self.rnnlm is not None:
            rnnlm_state, lm_scores = self.rnnlm.predict(rnnlm_state_prev, topi)
            lm_state = F.relu(self.lm_linear(lm_scores))
            gi = F.sigmoid(self.gate_linear(torch.cat((lm_state, z_list[-1]), dim=1)))
            output_in = torch.cat((z_list[-1], gi * lm_state), dim=1)
            rnnlm_state_prev = rnnlm_state
        else:
            output_in = z_list[-1]
        y_i = self.output(output_in)
        y_all.append(y_i)
        z_all.append(z_list[-1])

    y_all = torch.stack(y_all, dim=0).transpose(0, 1).contiguous().view(
        batch * olength, -1)
    self.loss = F.cross_entropy(y_all, pad_ys_out.view(-1),
                                ignore_index=self.ignore_id,
                                size_average=True)
    # -1: eos, which is removed in the loss computation
    self.loss *= (np.mean([len(x) for x in ys_in]) - 1)
    acc = th_accuracy(y_all, pad_ys_out, ignore_label=self.ignore_id)

    if self.labeldist is not None:
        if self.vlabeldist is None:
            self.vlabeldist = to_cuda(
                self, Variable(torch.from_numpy(self.labeldist)))
        loss_reg = -torch.sum(
            (F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1),
            dim=0) / len(ys_in)
        self.loss = (1. - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg

    return self.loss, acc
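# Hedged sketch of the scheduled-sampling branch above: with probability
# scheduled_sampling_rate (and only after the first step) the decoder is fed the
# embedding of its own argmax prediction instead of the ground-truth token.
# `embed` is assumed to be an nn.Embedding-like callable; all names are illustrative.
import random
import torch


def _example_scheduled_sampling_choice(y_prev_logits, gold_ids, embed, step,
                                       scheduled_sampling_rate=0.25):
    # y_prev_logits: B x odim logits from the previous step; gold_ids: B gold label ids
    if step > 0 and random.random() < scheduled_sampling_rate:
        token_ids = y_prev_logits.topk(1)[1].squeeze(1)   # model's own prediction
    else:
        token_ids = gold_ids                              # teacher forcing
    return embed(token_ids)                               # B x zdim token embedding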
def calculate_all_attentions(self, hpad, hlen, ys):
    '''Calculate all of the attention weights

    :param Variable hpad: padded encoder hidden states (B x T_max x D_enc)
    :param list hlen: encoder hidden state lengths (B)
    :param list ys: list of target label id sequences (B)
    :return: attention weights as a numpy array with shape (B, Lmax, Tmax)
    '''
    hlen = list(map(int, hlen))
    hpad = mask_by_length(hpad, hlen, 0)
    self.loss = None
    # prepare input and output word sequences with sos/eos IDs
    eos = Variable(ys[0].data.new([self.eos]))
    sos = Variable(ys[0].data.new([self.sos]))
    ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]

    # padding: inputs with eos, outputs with ignore_id (-1)
    # pys: utt x olen
    pad_ys_in = pad_list(ys_in, self.eos)
    pad_ys_out = pad_list(ys_out, self.ignore_id)

    # get length info
    olength = pad_ys_out.size(1)

    # initialization
    c_list = [self.zero_state(hpad)]
    z_list = [self.zero_state(hpad)]
    for l in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(hpad))
        z_list.append(self.zero_state(hpad))
    att_w = None
    att_ws = []
    self.att.reset()  # reset pre-computation of h

    # pre-computation of embedding
    eys = self.embed(pad_ys_in)  # utt x olen x zdim
    rnnlm_state_prev = None

    # loop for an output sequence
    for i in six.moves.range(olength):
        att_c, att_w = self.att(hpad, hlen, z_list[0], att_w)
        if i > 0:
            topv, topi = y_i.topk(1)
            topi = topi.squeeze(1)
            ey_top = self.embed(topi)  # utt x zdim
            ey = torch.cat((ey_top, att_c), dim=1)  # utt x (zdim + hdim)
        else:
            topi = pad_ys_in[:, i]
            ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
        z_list[0], c_list[0] = self.decoder[0](ey, (z_list[0], c_list[0]))
        for l in six.moves.range(1, self.dlayers):
            z_list[l], c_list[l] = self.decoder[l](
                z_list[l - 1], (z_list[l], c_list[l]))
        att_ws.append(att_w)

        if self.fusion == 'deep_fusion' and self.rnnlm is not None:
            rnnlm_state, lm_scores = self.rnnlm.predict(rnnlm_state_prev, topi)
            lm_state = rnnlm_state['h2']
            gi = F.sigmoid(self.gate_linear(lm_state))
            output_in = torch.cat((z_list[-1], gi * lm_state), dim=1)
            rnnlm_state_prev = rnnlm_state
        elif self.fusion == 'cold_fusion' and self.rnnlm is not None:
            rnnlm_state, lm_scores = self.rnnlm.predict(rnnlm_state_prev, topi)
            lm_state = F.relu(self.lm_linear(lm_scores))
            gi = F.sigmoid(self.gate_linear(torch.cat((lm_state, z_list[-1]), dim=1)))
            output_in = torch.cat((z_list[-1], gi * lm_state), dim=1)
            rnnlm_state_prev = rnnlm_state
        else:
            output_in = z_list[-1]
        y_i = self.output(output_in)

    # convert to numpy array with the shape (B, Lmax, Tmax)
    if isinstance(self.att, e2e_attention.AttLoc2D):
        # att_ws => list of previously concatenated attentions
        att_ws = torch.stack([aw[:, -1] for aw in att_ws], dim=1).data.cpu().numpy()
    elif isinstance(self.att, (e2e_attention.AttCov, e2e_attention.AttCovLoc)):
        # att_ws => list of lists of previous attentions
        att_ws = torch.stack([aw[-1] for aw in att_ws], dim=1).data.cpu().numpy()
    elif isinstance(self.att, e2e_attention.AttLocRec):
        # att_ws => list of tuples of attention and hidden states
        att_ws = torch.stack([aw[0] for aw in att_ws], dim=1).data.cpu().numpy()
    elif isinstance(self.att, (e2e_attention.AttMultiHeadDot,
                               e2e_attention.AttMultiHeadAdd,
                               e2e_attention.AttMultiHeadLoc,
                               e2e_attention.AttMultiHeadMultiResLoc)):
        # att_ws => list of lists of per-head attentions
        n_heads = len(att_ws[0])
        att_ws_sorted_by_head = []
        for h in six.moves.range(n_heads):
            att_ws_head = torch.stack([aw[h] for aw in att_ws], dim=1)
            att_ws_sorted_by_head += [att_ws_head]
        att_ws = torch.stack(att_ws_sorted_by_head, dim=1).data.cpu().numpy()
    else:
        # att_ws => list of attentions
        att_ws = torch.stack(att_ws, dim=1).data.cpu().numpy()
    return att_ws
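# Hedged sketch of the simplest conversion path above (the final else branch): for
# single-vector attentions each decoder step yields a B x T_max weight tensor, and
# stacking the per-step tensors along dim=1 gives the (B, L_max, T_max) array returned.
import torch


def _example_stack_attentions(att_ws):
    # att_ws: list (length L_max) of B x T_max tensors
    return torch.stack(att_ws, dim=1).cpu().numpy()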