def sample_one(self, input, input_dir, soft_score, soft_score_dir, state,
               state_dir, tmp_hiddens, tmp_hiddens_dir, contexts, mask, mask_dir):
    if self.config.global_emb:
        batch_size = contexts.size(0)
        a, b = self.embedding.weight.size()
        if soft_score is None:
            emb = self.embedding(input)
        else:
            emb1 = torch.bmm(
                soft_score.unsqueeze(1),
                self.embedding.weight.expand((batch_size, a, b)))
            emb2 = self.embedding(input)
            gamma = F.sigmoid(
                self.gated1(emb1.squeeze()) + self.gated2(emb2.squeeze()))
            emb = gamma * emb1.squeeze() + (1 - gamma) * emb2.squeeze()
        c, d = self.embedding_dir.weight.size()
        if soft_score_dir is None:
            emb_dir = self.embedding_dir(input_dir)
        else:
            emb1_dir = torch.bmm(
                soft_score_dir.unsqueeze(1),
                self.embedding_dir.weight.expand((batch_size, c, d)))
            emb2_dir = self.embedding_dir(input_dir)
            gamma_dir = F.sigmoid(
                self.gated1_dir(emb1_dir.squeeze()) +
                self.gated2_dir(emb2_dir.squeeze()))
            emb_dir = gamma_dir * emb1_dir.squeeze() + (
                1 - gamma_dir) * emb2_dir.squeeze()
    else:
        emb = self.embedding(input)
        emb_dir = self.embedding_dir(input_dir)

    output, state = self.rnn(emb, state)
    output_bk = output
    hidden, attn_weights = self.attention(output, contexts)
    if self.config.schmidt:
        hidden = models.schmidt(hidden, tmp_hiddens)
    output = self.compute_score(hidden)
    if self.config.mask:
        if mask is not None:
            output = output.scatter_(1, mask, -9999999999)

    output_dir, state_dir = self.rnn_dir(emb_dir, state_dir)
    output_dir_bk = output_dir
    hidden_dir, attn_weights_dir = self.attention_dir(output_dir, contexts)
    if self.config.schmidt:
        hidden_dir = models.schmidt(hidden_dir, tmp_hiddens_dir)
    output_dir = self.compute_score_dir(hidden_dir)
    if self.config.mask:
        if mask_dir is not None:
            output_dir = output_dir.scatter_(1, mask_dir, -9999999999)

    return (output, output_dir, state, state_dir, attn_weights,
            attn_weights_dir, hidden, hidden_dir, emb, emb_dir,
            output_bk, output_dir_bk)
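
# A minimal standalone illustration (not part of the model, arbitrary sizes) of the masking
# trick used in `sample_one` above: scattering a very large negative value into the score
# positions of already-predicted labels so that a subsequent softmax/argmax can no longer
# select them. Reuses the module-level `torch` / `F` imports.
def _mask_scatter_demo():
    batch_size, vocab_size = 2, 6
    scores = torch.zeros(batch_size, vocab_size)
    # Suppose label 3 was already predicted for sample 0 and label 1 for sample 1.
    mask = torch.tensor([[3], [1]])
    scores = scores.scatter_(1, mask, -9999999999)
    probs = F.softmax(scores, dim=-1)
    # The masked positions now receive (numerically) zero probability.
    return probs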
def sample_one(self, input, soft_score, state, tmp_hiddens, contexts, mask):
    if self.config.global_emb:
        batch_size = contexts.size(0)
        a, b = self.embedding.weight.size()
        if soft_score is None:
            emb = self.embedding(input)
        else:
            # Soft embedding: expectation of the embedding table under soft_score.
            emb1 = torch.bmm(
                soft_score.unsqueeze(1),
                self.embedding.weight.expand((batch_size, a, b)))
            if not self.config.all_soft:
                emb2 = self.embedding(input)
                gamma = F.sigmoid(
                    self.gated1(emb1.squeeze()) + self.gated2(emb2.squeeze()))
                emb = gamma * emb1.squeeze() + (1 - gamma) * emb2.squeeze()
            else:
                emb = emb1.squeeze()
    else:
        emb = self.embedding(input)

    output, state = self.rnn(emb, state)
    hidden, attn_weights = self.attention(output, contexts)
    if self.config.schmidt:
        hidden = models.schmidt(hidden, tmp_hiddens)
    output = self.compute_score(hidden, targets=None)
    # print(output)
    if self.config.mask:
        if mask is not None:
            # Exclude already-predicted labels by scattering a large negative score.
            output = output.scatter_(1, mask, -9999999999)
    return output, state, attn_weights, hidden, emb
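
# A minimal sketch (not part of the original model) of how the single-direction `sample_one`
# above might be driven for step-wise greedy decoding. The `decoder` object, `bos_id`, and
# `max_steps` are assumptions for illustration; only the call signature of
# `sample_one(input, soft_score, state, tmp_hiddens, contexts, mask)` is taken from the code above.
def greedy_decode_sketch(decoder, contexts, init_state, bos_id, max_steps=10):
    batch_size = contexts.size(0)
    inp = torch.full((batch_size,), bos_id, dtype=torch.long, device=contexts.device)
    soft_score, state, mask = None, init_state, None
    tmp_hiddens, predictions = [], []
    for _ in range(max_steps):
        output, state, attn, hidden, emb = decoder.sample_one(
            inp, soft_score, state, tmp_hiddens, contexts, mask)
        soft_score = F.softmax(output, dim=-1)  # feeds the global embedding at the next step
        inp = output.max(dim=1)[1]              # greedy pick of the next label
        predictions.append(inp)
        tmp_hiddens.append(hidden)
        # Accumulate predicted indices so scatter_ can block them from being chosen again.
        mask = inp.unsqueeze(1) if mask is None else torch.cat([mask, inp.unsqueeze(1)], dim=1)
    return torch.stack(predictions, dim=1)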
def forward(self, all_targets, init_state, contexts):
    inputs = all_targets[:-1]
    if not self.config.global_emb:
        # Without global_emb, simply feed the plain target embedding to the next step.
        embs = self.embedding(inputs)
        outputs, state, attns = [], init_state, []
        # One step per target label; the number of steps equals the largest
        # label count in this batch (targets are padded to that length).
        for emb in embs.split(1):
            output, state = self.rnn(emb.squeeze(0), state)
            output, attn_weights = self.attention(output, contexts)
            if self.config.schmidt:
                output = models.schmidt(output, outputs)
            output = self.dropout(output)
            if self.config.ct_recu:
                contexts = (1 - (attn_weights > 0.003).float()).unsqueeze(-1) * contexts
                # contexts = (1 - attn_weights).unsqueeze(-1) * contexts
            outputs += [output]
            attns += [attn_weights]
        outputs = torch.stack(outputs)
        attns = torch.stack(attns)
        return outputs, state, embs, torch.ones((3, 100)).cuda()  # placeholder return value
    else:
        outputs, state, attns, global_embs = [], init_state, [], []
        embs = self.embedding(inputs).split(1)  # per time step: [1, bs, emb_size]
        max_time_step = len(embs)
        emb = embs[0]  # embedding of BOS for the first step
        output, state = self.rnn(emb.squeeze(0), state)
        output, attn_weights = self.attention(output, contexts)
        output = self.dropout(output)
        if self.score_fn.startswith('arc_margin'):
            soft_score = F.softmax(self.linear(output, all_targets[1, :]))  # first-step distribution over the vocabulary, shape (bs, vocab)
        else:
            soft_score = F.softmax(self.linear(output))  # first-step distribution over the vocabulary, shape (bs, vocab)
        outputs += [output]
        attns += [attn_weights]
        batch_size = soft_score.size(0)
        a, b = self.embedding.weight.size()
        for i in range(max_time_step - 1):
            # Weighted average of all vocabulary embeddings under soft_score.
            emb1 = torch.bmm(soft_score.unsqueeze(1),
                             self.embedding.weight.expand((batch_size, a, b)))
            if not self.config.all_soft:
                emb2 = embs[i + 1]  # embedding of the next input token
                gamma = F.sigmoid(self.gated1(emb1.squeeze()) + self.gated2(emb2.squeeze()))  # Eq. (13) in the paper
                emb = gamma * emb1.squeeze() + (1 - gamma) * emb2.squeeze()
            else:
                gamma = torch.ones((3, 100)).cuda()  # placeholder value
                emb = emb1.squeeze()
            # Note: global_embs is offset by one step from embs, so its first entry already
            # reflects the result after the first target output, keeping it aligned with outputs.
            global_embs += [emb]
            output, state = self.rnn(emb, state)
            output, attn_weights = self.attention(output, contexts)
            output = self.dropout(output)
            if self.score_fn.startswith('arc_margin'):
                soft_score = F.softmax(self.linear(output, all_targets[i + 2, :]))  # distribution over the vocabulary at this step, shape (bs, vocab)
            else:
                soft_score = F.softmax(self.linear(output))  # distribution over the vocabulary at this step, shape (bs, vocab)
            outputs += [output]
            attns += [attn_weights]
        outputs = torch.stack(outputs)
        global_embs = torch.stack(global_embs)
        attns = torch.stack(attns)
        return outputs, state, global_embs, gamma
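
# A minimal standalone sketch (made-up sizes, not part of the model) of the gated
# global-embedding combination used above: emb1 is the expectation of the embedding table
# under the previous step's distribution, emb2 is the hard embedding of the fed-in token,
# and a sigmoid gate mixes the two, i.e. emb = gamma * emb1 + (1 - gamma) * emb2.
# Assumes the module-level `torch`, `nn`, and `F` imports.
def _global_emb_demo():
    batch_size, vocab_size, emb_size = 4, 10, 8
    embedding = nn.Embedding(vocab_size, emb_size)
    gate1 = nn.Linear(emb_size, emb_size)
    gate2 = nn.Linear(emb_size, emb_size)
    soft_score = F.softmax(torch.randn(batch_size, vocab_size), dim=-1)
    tokens = torch.randint(0, vocab_size, (batch_size,))
    # Expected embedding: (bs, 1, vocab) x (bs, vocab, emb) -> (bs, 1, emb)
    emb1 = torch.bmm(soft_score.unsqueeze(1),
                     embedding.weight.expand(batch_size, vocab_size, emb_size)).squeeze(1)
    emb2 = embedding(tokens)
    gamma = torch.sigmoid(gate1(emb1) + gate2(emb2))
    emb = gamma * emb1 + (1 - gamma) * emb2
    return emb  # shape: (batch_size, emb_size)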