Exemplo n.º 1
0
Arquivo: rnn.py Projeto: oucxlw/SDNet
    def sample_one(self, input, input_dir, soft_score, soft_score_dir, state,
                   state_dir, tmp_hiddens, tmp_hiddens_dir, contexts, mask,
                   mask_dir):
        if self.config.global_emb:
            batch_size = contexts.size(0)
            a, b = self.embedding.weight.size()
            if soft_score is None:
                emb = self.embedding(input)
            else:
                emb1 = torch.bmm(
                    soft_score.unsqueeze(1),
                    self.embedding.weight.expand((batch_size, a, b)))
                emb2 = self.embedding(input)
                gamma = F.sigmoid(
                    self.gated1(emb1.squeeze()) + self.gated2(emb2.squeeze()))
                emb = gamma * emb1.squeeze() + (1 - gamma) * emb2.squeeze()

            c, d = self.embedding_dir.weight.size()
            if soft_score_dir is None:
                emb_dir = self.embedding_dir(input_dir)
            else:
                emb1_dir = torch.bmm(
                    soft_score_dir.unsqueeze(1),
                    self.embedding_dir.weight.expand((batch_size, c, d)))
                emb2_dir = self.embedding_dir(input_dir)
                gamma_dir = F.sigmoid(
                    self.gated1_dir(emb1_dir.squeeze()) +
                    self.gated2_dir(emb2_dir.squeeze()))
                emb_dir = gamma_dir * emb1_dir.squeeze() + (
                    1 - gamma_dir) * emb2_dir.squeeze()
        else:
            emb = self.embedding(input)
            emb_dir = self.embedding_dir(input_dir)

        output, state = self.rnn(emb, state)
        output_bk = output
        hidden, attn_weights = self.attention(output, contexts)
        if self.config.schmidt:
            hidden = models.schmidt(hidden, tmp_hiddens)
        output = self.compute_score(hidden)
        if self.config.mask:
            if mask is not None:
                output = output.scatter_(1, mask, -9999999999)

        output_dir, state_dir = self.rnn_dir(emb_dir, state_dir)
        output_dir_bk = output_dir
        hidden_dir, attn_weights_dir = self.attention_dir(output_dir, contexts)
        if self.config.schmidt:
            hidden_dir = models.schmidt(hidden_dir, tmp_hiddens_dir)
        output_dir = self.compute_score_dir(hidden_dir)
        if self.config.mask:
            if mask_dir is not None:
                output_dir = output_dir.scatter_(1, mask_dir, -9999999999)

        return output, output_dir, state, state_dir, attn_weights, attn_weights_dir, hidden, hidden_dir, emb, emb_dir, output_bk, output_dir_bk
Exemplo n.º 2
0
 def sample_one(self, input, soft_score, state, tmp_hiddens, contexts,
                mask):
     if self.config.global_emb:
         batch_size = contexts.size(0)
         a, b = self.embedding.weight.size()
         if soft_score is None:
             emb = self.embedding(input)
         else:
             emb1 = torch.bmm(
                 soft_score.unsqueeze(1),
                 self.embedding.weight.expand((batch_size, a, b)))
             if not self.config.all_soft:
                 emb2 = self.embedding(input)
                 gamma = F.sigmoid(
                     self.gated1(emb1.squeeze()) +
                     self.gated2(emb2.squeeze()))
                 emb = gamma * emb1.squeeze() + (1 - gamma) * emb2.squeeze()
             else:
                 emb = emb1.squeeze()
     else:
         emb = self.embedding(input)
     output, state = self.rnn(emb, state)
     hidden, attn_weigths = self.attention(output, contexts)
     if self.config.schmidt:
         hidden = models.schmidt(hidden, tmp_hiddens)
     output = self.compute_score(hidden, targets=None)
     # print(output)
     if self.config.mask:
         if mask is not None:
             output = output.scatter_(1, mask, -9999999999)
     return output, state, attn_weigths, hidden, emb
Exemplo n.º 3
0
    def forward(self, all_targets, init_state, contexts):
        inputs=all_targets[:-1]
        if not self.config.global_emb:
            embs = self.embedding(inputs)  # 如果不用global_emb,那么就直接一个简单的emb给下一个step.
            outputs, state, attns = [], init_state, []
            for emb in embs.split(1):  # 对于每个step的target词语, 这个循环应该是针对训练集当中的n的(目标标签数目),就是把第一个排列成一个tuple(也就是这个batch里label最多的那个的label数目)
                output, state = self.rnn(emb.squeeze(0), state)
                output, attn_weights = self.attention(output, contexts)
                if self.config.schmidt:
                    output = models.schmidt(output, outputs)
                output = self.dropout(output)
                if self.config.ct_recu:
                    contexts = (1 - (attn_weights > 0.003).float()).unsqueeze(-1) * contexts
                    # contexts= (1-attn_weights).unsqueeze(-1)*contexts
                outputs += [output]
                attns += [attn_weights]
            outputs = torch.stack(outputs)
            attns = torch.stack(attns)
            return outputs, state, embs, torch.ones((3,100)).cuda() #随便给一个
        else:
            outputs, state, attns, global_embs = [], init_state, [], []
            embs = self.embedding(inputs).split(1)  # time_step [1,bs,embsize]
            max_time_step = len(embs)
            emb = embs[0]  # 第一步BOS的embedding.
            output, state = self.rnn(emb.squeeze(0), state)
            output, attn_weights = self.attention(output, contexts)
            output = self.dropout(output)
            if self.score_fn.startswith('arc_margin'):
                soft_score = F.softmax(self.linear(output,all_targets[1,:]))  # 第一步的概率分布也就是 bs,vocal这么大
            else:
                soft_score = F.softmax(self.linear(output))  # 第一步的概率分布也就是 bs,vocal这么大
            outputs += [output]
            attns += [attn_weights]

            batch_size = soft_score.size(0)
            a, b = self.embedding.weight.size()

            for i in range(max_time_step - 1):
                emb1 = torch.bmm(soft_score.unsqueeze(1), self.embedding.weight.expand((batch_size, a, b)))  # 对vocab上所有emb按得分的加权平均
                if not self.config.all_soft:
                    emb2 = embs[i + 1]  # 下一步的输入词语的emb
                    gamma = F.sigmoid(self.gated1(emb1.squeeze()) + self.gated2(emb2.squeeze()))  # 对应文中13式子
                    emb = gamma * emb1.squeeze() + (1 - gamma) * emb2.squeeze()
                else:
                    gamma = torch.ones((3,100)).cuda() #随便给一个
                    emb = emb1.squeeze()
                global_embs += [emb]  # 注意这个global和上面的错一个step,相当与emb的第一个就是目标输出了之后的结果,跟上面那种Outputs保持一致
                output, state = self.rnn(emb, state)
                output, attn_weights = self.attention(output, contexts)
                output = self.dropout(output)
                if self.score_fn.startswith('arc_margin'):
                    soft_score = F.softmax(self.linear(output,all_targets[i+2,:]))  # 第一步的概率分布也就是 bs,vocal这么大
                else:
                    soft_score = F.softmax(self.linear(output))  # 第一步的概率分布也就是 bs,vocal这么大
                outputs += [output]
                attns += [attn_weights]
            outputs = torch.stack(outputs)
            global_embs = torch.stack(global_embs)
            attns = torch.stack(attns)
            return outputs, state, global_embs, gamma