def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
    '''AttMultiHeadDot forward

    :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
    :param list enc_hs_len: padded encoder hidden state length (B)
    :param Variable dec_z: decoder hidden state (B x D_dec)
    :param Variable att_prev: dummy (not used)
    :return: attention weighted encoder state (B x D_enc)
    :rtype: Variable
    :return: list of previous attention weights (B x T_max) * aheads
    :rtype: list
    '''
    batch = enc_hs_pad.size(0)
    # pre-compute all k and v outside the decoder loop
    if self.pre_compute_k is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)
        # utt x frame x att_dim
        self.pre_compute_k = [
            torch.tanh(linear_tensor(self.mlp_k[h], self.enc_h))
            for h in six.moves.range(self.aheads)
        ]

    if self.pre_compute_v is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)
        # utt x frame x att_dim
        self.pre_compute_v = [
            linear_tensor(self.mlp_v[h], self.enc_h)
            for h in six.moves.range(self.aheads)
        ]

    if dec_z is None:
        dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
    else:
        dec_z = dec_z.view(batch, self.dunits)

    c = []
    w = []
    for h in six.moves.range(self.aheads):
        # utt x frame
        e = torch.sum(
            self.pre_compute_k[h] *
            torch.tanh(self.mlp_q[h](dec_z)).view(batch, 1, self.att_dim_k),
            dim=2)
        w += [F.softmax(self.scaling * e, dim=1)]

        # weighted sum over frames
        # utt x hdim
        # NOTE use bmm instead of sum(*)
        c += [
            torch.sum(
                self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1),
                dim=1)
        ]

    # concat all of c
    c = self.mlp_o(torch.cat(c, dim=1))

    return c, w
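The `NOTE use bmm instead of sum(*)` comment above refers to an equivalent way of computing each head's context vector. A minimal sketch of that equivalence, assuming one head's weights `w_h` have shape (B x T_max) and its pre-computed values `v_h` have shape (B x T_max x att_dim_v) (the value dimension name is an assumption, not shown in this section):

def _head_context_bmm(w_h, v_h):
    # (B, 1, T_max) @ (B, T_max, att_dim_v) -> (B, 1, att_dim_v) -> (B, att_dim_v);
    # same result as torch.sum(v_h * w_h.view(B, T_max, 1), dim=1)
    return torch.bmm(w_h.unsqueeze(1), v_h).squeeze(1)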
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
    '''AttLoc forward

    :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
    :param list enc_hs_len: padded encoder hidden state length (B)
    :param Variable dec_z: decoder hidden state (B x D_dec)
    :param Variable att_prev: dummy (not used)
    :param float scaling: scaling parameter before applying softmax
    :return: attention weighted encoder state (B, D_enc)
    :rtype: Variable
    :return: previous attention weights (B x T_max)
    :rtype: Variable
    '''
    batch = len(enc_hs_pad)
    # pre-compute all h outside the decoder loop
    if self.pre_compute_enc_h is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)
        # utt x frame x att_dim
        self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)

    if dec_z is None:
        dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
    else:
        dec_z = dec_z.view(batch, self.dunits)

    # dec_z_tiled: utt x frame x att_dim
    dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

    # dot with gvec
    # utt x frame x att_dim -> utt x frame
    # NOTE consider zero padding when computing w.
    e = linear_tensor(
        self.gvec,
        torch.tanh(self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
    w = F.softmax(scaling * e, dim=1)

    # weighted sum over frames
    # utt x hdim
    # NOTE use bmm instead of sum(*)
    c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

    return c, w
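For context, a minimal sketch of how a decoder loop typically drives these attention modules: on the first step both the decoder state and the attention state are passed as `None`, and the returned weights are threaded back in on subsequent steps. `att_module` and `decoder_step` are placeholder names for illustration, not identifiers from the original code, and `pre_compute_enc_h` is assumed to be reset to None before each new batch.

def attend_over_utterance(att_module, decoder_step, enc_hs_pad, enc_hs_len, max_len):
    """Sketch of the calling convention shared by the forward methods in this section."""
    contexts = []
    att_state = None  # first call: the module initializes its own attention state
    dec_z = None      # first call: the module substitutes a zero decoder state
    for _ in range(max_len):
        att_c, att_state = att_module(enc_hs_pad, enc_hs_len, dec_z, att_state)
        dec_z = decoder_step(att_c, dec_z)  # hypothetical decoder state update
        contexts.append(att_c)
    return contexts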
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0):
    '''AttCovLoc forward

    :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
    :param list enc_hs_len: padded encoder hidden state length (B)
    :param Variable dec_z: decoder hidden state (B x D_dec)
    :param list att_prev_list: list of previous attention weights
    :param float scaling: scaling parameter before applying softmax
    :return: attention weighted encoder state (B, D_enc)
    :rtype: Variable
    :return: list of previous attention weights
    :rtype: list
    '''
    batch = len(enc_hs_pad)
    # pre-compute all h outside the decoder loop
    if self.pre_compute_enc_h is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)
        # utt x frame x att_dim
        self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)

    if dec_z is None:
        dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
    else:
        dec_z = dec_z.view(batch, self.dunits)

    # initialize attention weight with uniform dist.
    if att_prev_list is None:
        att_prev = [
            Variable(enc_hs_pad.data.new(l).zero_() + (1.0 / l))
            for l in enc_hs_len
        ]
        # if no bias, 0 0-pad goes 0
        att_prev_list = [pad_list(att_prev, 0)]

    # att_prev_list: L' * [B x T] => cov_vec B x T
    cov_vec = sum(att_prev_list)

    # cov_vec: B x T -> B x 1 x 1 x T -> B x C x 1 x T
    att_conv = self.loc_conv(cov_vec.view(batch, 1, 1, self.h_length))
    # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
    att_conv = att_conv.squeeze(2).transpose(1, 2)
    # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
    att_conv = linear_tensor(self.mlp_att, att_conv)

    # dec_z_tiled: utt x frame x att_dim
    dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

    # dot with gvec
    # utt x frame x att_dim -> utt x frame
    # NOTE consider zero padding when computing w.
    e = linear_tensor(
        self.gvec,
        torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
    w = F.softmax(scaling * e, dim=1)
    att_prev_list += [w]

    # weighted sum over frames
    # utt x hdim
    # NOTE use bmm instead of sum(*)
    c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

    return c, att_prev_list
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_states, scaling=2.0):
    '''AttLocRec forward

    :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
    :param list enc_hs_len: padded encoder hidden state length (B)
    :param Variable dec_z: decoder hidden state (B x D_dec)
    :param tuple att_prev_states: previous attention weights and LSTM states
        ((B, T_max), ((B, att_dim), (B, att_dim)))
    :param float scaling: scaling parameter before applying softmax
    :return: attention weighted encoder state (B, D_enc)
    :rtype: Variable
    :return: previous attention weights and LSTM states (w, (hx, cx))
        ((B, T_max), ((B, att_dim), (B, att_dim)))
    :rtype: tuple
    '''
    batch = len(enc_hs_pad)
    # pre-compute all h outside the decoder loop
    if self.pre_compute_enc_h is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)
        # utt x frame x att_dim
        self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)

    if dec_z is None:
        dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
    else:
        dec_z = dec_z.view(batch, self.dunits)

    if att_prev_states is None:
        # initialize attention weight with uniform dist.
        att_prev = [
            Variable(enc_hs_pad.data.new(l).fill_(1.0 / l))
            for l in enc_hs_len
        ]
        # if no bias, 0 0-pad goes 0
        att_prev = pad_list(att_prev, 0)

        # initialize lstm states
        att_h = Variable(enc_hs_pad.data.new(batch, self.att_dim).zero_())
        att_c = Variable(enc_hs_pad.data.new(batch, self.att_dim).zero_())
        att_states = (att_h, att_c)
    else:
        att_prev = att_prev_states[0]
        att_states = att_prev_states[1]

    # B x 1 x 1 x T -> B x C x 1 x T
    att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
    # apply non-linear
    att_conv = F.relu(att_conv)
    # B x C x 1 x T -> B x C x 1 x 1 -> B x C
    att_conv = F.max_pool2d(att_conv, (1, att_conv.size(3))).view(batch, -1)

    att_h, att_c = self.att_lstm(att_conv, att_states)

    # dec_z_tiled: utt x frame x att_dim
    dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

    # dot with gvec
    # utt x frame x att_dim -> utt x frame
    # NOTE consider zero padding when computing w.
    e = linear_tensor(
        self.gvec,
        torch.tanh(att_h.unsqueeze(1) + self.pre_compute_enc_h +
                   dec_z_tiled)).squeeze(2)
    w = F.softmax(scaling * e, dim=1)

    # weighted sum over frames
    # utt x hdim
    # NOTE use bmm instead of sum(*)
    c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

    return c, (w, (att_h, att_c))
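The loop sketched earlier also covers this recurrent variant; the only difference is that the attention state is the whole (w, (hx, cx)) tuple, which is threaded back in unchanged. A minimal sketch, again with hypothetical `att_module` and `decoder_step` placeholders:

def attend_with_recurrent_state(att_module, decoder_step, enc_hs_pad, enc_hs_len, max_len):
    """Sketch: same loop, but the attention state is a (w, (hx, cx)) tuple."""
    contexts = []
    att_state = None  # first call: uniform weights and zero LSTM states are created inside
    dec_z = None
    for _ in range(max_len):
        att_c, att_state = att_module(enc_hs_pad, enc_hs_len, dec_z, att_state)
        dec_z = decoder_step(att_c, dec_z)  # hypothetical decoder state update
        contexts.append(att_c)
    return contexts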
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
    '''AttLoc2D forward

    :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
    :param list enc_hs_len: padded encoder hidden state length (B)
    :param Variable dec_z: decoder hidden state (B x D_dec)
    :param Variable att_prev: previous attention weights (B x att_win x T_max)
    :param float scaling: scaling parameter before applying softmax
    :return: attention weighted encoder state (B, D_enc)
    :rtype: Variable
    :return: previous attention weights (B x att_win x T_max)
    :rtype: Variable
    '''
    batch = len(enc_hs_pad)
    # pre-compute all h outside the decoder loop
    if self.pre_compute_enc_h is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)
        # utt x frame x att_dim
        self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)

    if dec_z is None:
        dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
    else:
        dec_z = dec_z.view(batch, self.dunits)

    # initialize attention weight with uniform dist.
    if att_prev is None:
        # B * [Li x att_win]
        att_prev = [
            Variable(enc_hs_pad.data.new(l, self.att_win).zero_() + 1.0 / l)
            for l in enc_hs_len
        ]
        # if no bias, 0 0-pad goes 0
        att_prev = pad_list(att_prev, 0).transpose(1, 2)

    # att_prev: B x att_win x Tmax -> B x 1 x att_win x Tmax -> B x C x 1 x Tmax
    att_conv = self.loc_conv(att_prev.unsqueeze(1))
    # att_conv: B x C x 1 x Tmax -> B x Tmax x C
    att_conv = att_conv.squeeze(2).transpose(1, 2)
    # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
    att_conv = linear_tensor(self.mlp_att, att_conv)

    # dec_z_tiled: utt x frame x att_dim
    dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

    # dot with gvec
    # utt x frame x att_dim -> utt x frame
    # NOTE consider zero padding when computing w.
    e = linear_tensor(
        self.gvec,
        torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
    w = F.softmax(scaling * e, dim=1)

    # weighted sum over frames
    # utt x hdim
    # NOTE use bmm instead of sum(*)
    c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

    # update att_prev: B x att_win x Tmax -> B x att_win+1 x Tmax -> B x att_win x Tmax
    att_prev = torch.cat([att_prev, w.unsqueeze(1)], dim=1)
    att_prev = att_prev[:, 1:]

    return c, att_prev
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
    '''AttLoc forward

    :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
    :param list enc_hs_len: padded encoder hidden state length (B)
    :param Variable dec_z: decoder hidden state (B x D_dec)
    :param Variable att_prev: previous attention weights (B x T_max)
    :param float scaling: scaling parameter before applying softmax
    :return: attention weighted encoder state (B, D_enc)
    :rtype: Variable
    :return: previous attention weights (B x T_max)
    :rtype: Variable
    '''
    batch = len(enc_hs_pad)
    # pre-compute all h outside the decoder loop
    if self.pre_compute_enc_h is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)
        # utt x frame x att_dim
        self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)

    if dec_z is None:
        dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
    else:
        dec_z = dec_z.view(batch, self.dunits)

    # initialize attention weight with uniform dist.
    if att_prev is None:
        att_prev = [
            Variable(enc_hs_pad.data.new(l).zero_() + (1.0 / l))
            for l in enc_hs_len
        ]
        # if no bias, 0 0-pad goes 0
        att_prev = pad_list(att_prev, 0)

    # att_prev: utt x frame -> utt x 1 x 1 x frame -> utt x att_conv_chans x 1 x frame
    att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
    # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
    att_conv = att_conv.squeeze(2).transpose(1, 2)
    # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
    att_conv = linear_tensor(self.mlp_att, att_conv)

    # dec_z_tiled: utt x frame x att_dim
    dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

    # dot with gvec
    # utt x frame x att_dim -> utt x frame
    # NOTE consider zero padding when computing w.
    e = linear_tensor(
        self.gvec,
        torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)

    # added by bliu: choose the attention activation; 'sigmoid' replaces softmax,
    # 'sigmoid_softmax' squashes the scores with sigmoid before the softmax
    if self.aact_fuc == 'softmax':
        w = F.softmax(scaling * e, dim=1)
    elif self.aact_fuc == 'sigmoid':
        w = F.sigmoid(scaling * e)
    elif self.aact_fuc == 'sigmoid_softmax':
        e = torch.sigmoid(e)
        w = F.softmax(scaling * e, dim=1)

    # weighted sum over frames
    # utt x hdim
    c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

    return c, w
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
    '''AttMultiHeadMultiResLoc forward

    :param Variable enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
    :param list enc_hs_len: padded encoder hidden state length (B)
    :param Variable dec_z: decoder hidden state (B x D_dec)
    :param Variable att_prev: list of previous attention weights (B x T_max) * aheads
    :return: attention weighted encoder state (B x D_enc)
    :rtype: Variable
    :return: list of previous attention weights (B x T_max) * aheads
    :rtype: list
    '''
    batch = enc_hs_pad.size(0)
    # pre-compute all k and v outside the decoder loop
    if self.pre_compute_k is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)
        # utt x frame x att_dim
        self.pre_compute_k = [
            linear_tensor(self.mlp_k[h], self.enc_h)
            for h in six.moves.range(self.aheads)
        ]

    if self.pre_compute_v is None:
        self.enc_h = enc_hs_pad  # utt x frame x hdim
        self.h_length = self.enc_h.size(1)
        # utt x frame x att_dim
        self.pre_compute_v = [
            linear_tensor(self.mlp_v[h], self.enc_h)
            for h in six.moves.range(self.aheads)
        ]

    if dec_z is None:
        dec_z = Variable(enc_hs_pad.data.new(batch, self.dunits).zero_())
    else:
        dec_z = dec_z.view(batch, self.dunits)

    if att_prev is None:
        att_prev = []
        for h in six.moves.range(self.aheads):
            # initialize attention weight with uniform dist.
            att_prev += [[
                Variable(enc_hs_pad.data.new(l).zero_() + (1.0 / l))
                for l in enc_hs_len
            ]]
            # if no bias, 0 0-pad goes 0
            att_prev[h] = pad_list(att_prev[h], 0)

    c = []
    w = []
    for h in six.moves.range(self.aheads):
        att_conv = self.loc_conv[h](att_prev[h].view(
            batch, 1, 1, self.h_length))
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        att_conv = linear_tensor(self.mlp_att[h], att_conv)

        e = linear_tensor(
            self.gvec[h],
            torch.tanh(self.pre_compute_k[h] + att_conv +
                       self.mlp_q[h](dec_z).view(
                           batch, 1, self.att_dim_k))).squeeze(2)
        w += [F.softmax(self.scaling * e, dim=1)]

        # weighted sum over frames
        # utt x hdim
        # NOTE use bmm instead of sum(*)
        c += [
            torch.sum(
                self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1),
                dim=1)
        ]

    # concat all of c
    c = self.mlp_o(torch.cat(c, dim=1))

    return c, w