Example #1
    def translate_sentence(self, src_seq):
        """
        Decode a single sentence.
        :param src_seq: tensor of shape [1, Lq]
        :return: list of decoded target token indices
        """
        # TODO: expand to batch operation.
        assert src_seq.size(0) == 1

        src_pad_idx, trg_eos_idx = self.src_pad_idx, self.trg_eos_idx
        max_seq_len, beam_size, alpha = self.max_seq_len, self.beam_size, self.alpha

        with torch.no_grad():
            # Padding mask used for self-attention over src: [1, 1, Lq]
            src_mask = get_pad_mask(src_seq, src_pad_idx)
            # Initial state: enc_output [beam_size, Lq, d_model],
            # gen_seq [beam_size, max_seq_len], scores [1, 1, beam_size].
            # Column 0 of gen_seq is <bos>; column 1 holds the highest-probability first tokens.
            enc_output, gen_seq, scores = self._get_init_state(src_seq, src_mask)

            ans_idx = 0  # default
            # Decode step by step up to the maximum length.
            for step in range(2, max_seq_len):  # decode up to max length
                # Inputs: [beam_size, step], [beam_size, Lq, d_model] and [1, 1, Lq];
                ## output: [beam_size, step, vocab_size], a vocabulary distribution
                ## for every position in every beam.
                dec_output = self._model_decode(gen_seq[:, :step], enc_output, src_mask)
                # Update past and current positions of gen_seq and refresh the scores.
                gen_seq, scores = self._get_the_best_score_and_idx(gen_seq, dec_output, scores, step)

                # Check if all paths finished
                # -- locate the eos in the generated sequences
                ## Boolean mask of eos positions
                ## [beam_size, max_seq_len]
                eos_locs = gen_seq == trg_eos_idx
                # -- replace the eos with its position for the length penalty use
                ## Use the mask to get each sequence's length at its first eos: [beam_size]
                seq_lens, _ = self.len_map.masked_fill(~eos_locs, max_seq_len).min(1)
                # -- check if all beams contain eos
                ## Decoding stops once every beam contains an eos token.
                if (eos_locs.sum(1) > 0).sum(0).item() == beam_size:
                    # TODO: Try different terminate conditions.
                    # Divide the scores by the length penalty (length ** alpha);
                    # dividing by a power of the length favors shorter sequences.
                    ## The division yields [1, beam_size]; take the index of the best-scoring path.
                    _, ans_idx = scores.div(seq_lens.float() ** alpha).max(1)
                    ans_idx = ans_idx.item()
                    break
        # Return the decoded sequence of the best-scoring path.
        return gen_seq[ans_idx][:seq_lens[ans_idx]].tolist()
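
A minimal driver for this method, assuming `translator` is an already-constructed instance of the surrounding class (trained model, pad/eos indices, and beam settings in place); the token ids below are invented for illustration:

import torch

# Hypothetical usage; `translator` and the token ids are assumptions,
# not part of the original example.
src_seq = torch.LongTensor([[4, 17, 256, 9, 3]])   # [1, Lq]
pred_ids = translator.translate_sentence(src_seq)  # list of target token ids
print(pred_ids)                                    # last element should be the <eos> index
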
Example #2
    def forward(self, inputs):
        """
        Nu is the number of utterance turns, Lu the length of one utterance,
        Lr the length of one response.
        """
        # [B, Nu, Lu]
        utterances = inputs['utterances']
        # [B, Lr]
        response = inputs['response']
        bsz = utterances.size(0)
        # Tensor of turn indices
        turns_num = inputs['turns']
        # Mask the padding in the utterances [B*Nu, 1, Lu]
        uttrs_mask = get_pad_mask(utterances.view(-1, self.uttr_len),
                                  self._params['padding_idx'])
        uttrs_embed_mask = uttrs_mask.squeeze(dim=-2).unsqueeze(
            dim=-1)  # [B*Nu, Lu, 1]
        # Mask the padding in the response [B, 1, Lr]
        resp_mask = get_pad_mask(response, self._params['padding_idx'])
        resp_embed_mask = resp_mask.squeeze(dim=-2).unsqueeze(
            dim=-1)  # [B, Lr, 1]

        # --------------- Embedding layer -----------------
        ## [B, Nu, Lu, embed_output_dim]
        uttrs_embedding = self.word_emb(utterances)
        ## [B, Lr, embed_output_dim]
        resp_embedding = self.word_emb(response)
        ## Optionally add the turn embedding
        if self._params['is_turn_emb']:
            ### [1, Nu, 1, embed_output_dim]
            turns_embedding = self.turn_emb(turns_num).unsqueeze(dim=-2)
            uttrs_embedding = uttrs_embedding + turns_embedding

        if self._params['is_position_emb']:
            ### [B*Nu, Lu, embed_output_dim]
            uttrs_embedding = self.position_emb(
                uttrs_embedding.view(bsz * self.turns, self.uttr_len, -1))
            ### [B, Lr, embed_output_dim]
            resp_embedding = self.position_emb(resp_embedding)
        ## [B*Nu, Lu, d_model]
        # uttrs_embedding = self.word_proj(uttrs_embedding)
        # ## [B, Lr, d_model]
        # resp_embedding = self.word_proj(resp_embedding)
        U_emb = uttrs_embedding * uttrs_embed_mask
        R_emb = resp_embedding * resp_embed_mask

        # -------------------- Attentive HR Encoder -----------------
        ## [B*Nu, Lu, NL, 2*hid]
        U_stack = self.stack_brnn(U_emb, uttrs_embed_mask)
        ## [B, Lr, NL, 2*hid]
        R_stack = self.stack_brnn(R_emb, resp_embed_mask)
        ## [NL, 1]
        wm = F.softmax(self.w_m, dim=0).unsqueeze(-1)
        ## Weighted attention over the utterance layers [B*Nu, Lu, 2*hid]
        U_attn = torch.matmul(U_stack.transpose(2, 3), wm).squeeze(dim=-1)
        ## Weighted attention over the response layers [B, Lr, 2*hid]
        R_attn = torch.matmul(R_stack.transpose(2, 3), wm).squeeze(dim=-1)

        # ------------------ Matching Layer -------------------------

        U_attn_reshape = U_attn.view(bsz, self.turns * self.uttr_len, -1)
        ## [B, Nu*Lu, 2*hid]
        U_R_attn, *_ = self.cross_attn(U_attn_reshape, R_attn, R_attn,
                                       resp_mask)
        ## [B*Nu, Lu, 2*hid]
        U_R_attn = U_R_attn.view(bsz * self.turns, self.uttr_len, -1)
        ## [B, Lr, 2*hid]
        R_U_attn, *_ = self.cross_attn(
            R_attn, U_attn_reshape, U_attn_reshape,
            uttrs_mask.squeeze(dim=1).view(bsz, self.turns *
                                           self.uttr_len).unsqueeze(dim=-2))
        ## [B*Nu, Lu, 8*hid]
        C_mat = self.get_matching_tensor(U_attn, U_R_attn)
        ## [B, Lr, 8*hid]
        R_mat = self.get_matching_tensor(R_attn, R_U_attn)

        # ----------------- Aggregation Layer -----------------------

        ## [B*Nu, Lu, 2*hid] and [2, B*Nu, hid]
        C_out, C_state = self.sent_gru(C_mat)
        ## [B, Lr, 2*hid] and [2, B, hid]
        R_out, R_state = self.sent_gru(R_mat)
        ## Max pooling and mean pooling
        ### Each of the three below is [B*Nu, 2*hid]
        C_mean = torch.mean(C_out, dim=1)
        C_max = torch.max(C_out, dim=1)[0]
        C_state = C_state.transpose(0,
                                    1).contiguous().view(bsz * self.turns, -1)
        ## [B, Nu, 6*hid]
        C_out = torch.cat([C_mean, C_max, C_state],
                          dim=-1).view(bsz, self.turns, -1)
        ## [B, Nu, 2*hid] and [2, B, hid]
        C_out, C_state = self.uttr_gru(C_out)
        ## Aggregate the utterances and the response
        C_mean = torch.mean(C_out, dim=1)  # [B, 2*hid]
        C_max = torch.max(C_out, dim=1)[0]  # [B, 2*hid]
        C_state = C_state.transpose(0, 1).contiguous().view(bsz, -1)

        R_mean = torch.mean(R_out, dim=1)  # [B, 2*hid]
        R_max = torch.max(R_out, dim=1)[0]  # [B, 2*hid]
        R_state = R_state.transpose(0, 1).contiguous().view(bsz, -1)

        ## [B, 12*hid]
        M_agr = torch.cat([C_mean, C_max, C_state, R_mean, R_max, R_state],
                          dim=-1)

        # ------------------ Output Layer ---------------------
        output = self.dropout(M_agr)
        output = self.mlps(output)
        output = self.output(output)
        return output
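
`get_pad_mask` is used throughout these examples but never shown. A plausible implementation, inferred purely from the shape comments above (a [B, 1, L] mask that is truthy at non-pad positions), is sketched here; it is an assumption, not the original helper:

import torch

def get_pad_mask(seq, pad_idx):
    # seq: [B, L] token ids -> [B, 1, L] boolean mask, True where the token
    # is not padding. Inferred from the shape comments; the real helper may
    # differ (e.g. return floats instead of booleans).
    return (seq != pad_idx).unsqueeze(-2)
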
Example #3
    def forward(self, inputs):
        """
        Nu is the number of utterance turns, Lu the length of one utterance,
        Lr the length of one response.
        """
        # ------------- Process the utterances -------------------
        ## [B, Nu, Lu]; here Lu is the padded length of each utterance
        utterances = inputs['utterances']
        bsz = utterances.size(0)  ## current batch size
        ## [B*Nu, 1, Lu]
        uttrs_mask = get_pad_mask(utterances,
                                  self._params['padding_idx']).view(
                                      bsz * self.turns, 1, self.uttr_len)
        ## [B*Nu, Lu, 1]
        uttrs_emb_mask = uttrs_mask.squeeze(dim=-2).unsqueeze(dim=-1)
        ## [B, Nu, Lu, emb_output]
        uttrs_emb = self.word_emb(utterances)
        ## [B*Nu, Lu, d_model]
        uttrs_emb = self.word_proj(uttrs_emb)
        ## Tensor of turn indices
        ### [1, Nu]
        turns_num = inputs['turns']
        if self._params['is_turn_emb']:
            ## [1, Nu, 1, d_model]
            turns_embedding = self.turn_emb(turns_num).unsqueeze(dim=-2)
            ## [B, Nu, Lu, d_model]
            uttrs_emb = uttrs_emb + turns_embedding
        ## [B*Nu, Lu, d_model]
        uttrs_emb = uttrs_emb.view(bsz * self.turns, self.uttr_len, -1)
        if self._params['is_position_emb']:
            ### [B*Nu, Lu, d_model]
            uttrs_emb = self.position_emb(uttrs_emb)
        ## [B*Nu, Lu, d_model]; mask the embeddings
        uttrs_emb = uttrs_emb * uttrs_emb_mask
        ## [NL, B*Nu, Lu, d_model]; encoded by the self-attention encoder
        uttrs_es, *_ = self.encoder(uttrs_emb, uttrs_mask)
        uttrs_es_list = [uttrs_emb] + uttrs_es
        L_layer = len(uttrs_es_list)
        ## [B*Nu*(NL+1), Lu, d_model]
        uttrs_stack = torch.stack(uttrs_es_list, dim=0).transpose(
            0, 1).contiguous().view(bsz * self.turns * L_layer, self.uttr_len,
                                    -1)
        # ---------------- Collect all layer outputs for the response -------------------
        # [B, Lr]
        response = inputs['response']
        ## [B, 1, Lr]
        resp_mask = get_pad_mask(response, self._params['padding_idx'])
        ## [B, Lr, 1]
        resp_emb_mask = resp_mask.squeeze(dim=-2).unsqueeze(dim=-1)
        ## [B, Lr, emb_output]
        resp_emb = self.word_emb(response)
        ## [B, Lr, d_model]
        resp_emb = self.word_proj(resp_emb)
        if self._params['is_position_emb']:
            resp_emb = self.position_emb(resp_emb)
        resp_emb = resp_emb * resp_emb_mask
        ## [NL, B, Lr, d_model]
        resp_es, *_ = self.encoder(resp_emb, resp_mask)
        ## [NL+1, B, Lr, d_model]
        rep_es_list = [resp_emb] + resp_es
        ## [NL+1, B, Lr, d_model] tensor
        resp_stack = torch.stack(rep_es_list, dim=0)
        ## [B*Nu*(NL+1), Lr, d_model]
        resp_stack = resp_stack.repeat(1, self.turns, 1, 1).transpose(
            0, 1).contiguous().view(bsz * self.turns * L_layer, self.resp_len,
                                    -1)

        # -------------------- Cross Attention --------------------------
        ## [B*Nu*(NL+1), Lu, d_model]
        UR_att, *_ = self.cross_att(
            uttrs_stack, resp_stack, resp_stack,
            resp_mask.repeat(self.turns * L_layer, 1, 1))
        ## [B*Nu*(NL+1), Lr, d_model]
        RU_att, *_ = self.cross_att(resp_stack, uttrs_stack, uttrs_stack,
                                    uttrs_mask.repeat(L_layer, 1, 1))
        # ------- Self-attention match --------
        ## [B, Nu, NL+1, Lu, Lr]
        ### **Since this computes a similarity, apply an activation to avoid
        ### gradient explosion.**
        M_self = torch.tanh(
            torch.matmul(uttrs_stack, resp_stack.transpose(1, 2)).view(
                bsz, self.turns, L_layer, self.uttr_len, self.resp_len))
        # ------- Cross-attention match -------
        ## [B, Nu, NL+1, Lu, Lr]
        ### **Since this computes a similarity, apply an activation to avoid
        ### gradient explosion.**
        M_cross = torch.tanh(
            torch.matmul(UR_att, RU_att.transpose(1, 2)).view(
                bsz, self.turns, L_layer, self.uttr_len, self.resp_len))
        # [B, 2(NL+1), Nu, Lu, Lr]
        output = torch.cat([M_self, M_cross], dim=2).transpose(1, 2)

        # [B, conv1_channels, Nu/2, Lu/4, Lr/4]
        output = self.pool1(
            self.cons_pad(self.conv1_activation(self.conv1(output))))
        # [B, conv2_channels, 1, 1, 1]
        output = self.pool2(self.conv2_activation(self.conv2(output)))
        # [B, conv2_channels]
        output = output.view(bsz, -1)
        output = self.dropout(output)
        # [B, fan_output]
        # output = self.mlps(output)
        # [B, num_classes] or [B, 1]
        output = self.output(output)
        return output
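
To make the shape bookkeeping of the matching step concrete, here is a standalone sketch with invented sizes; random tensors stand in for uttrs_stack, resp_stack, and their cross-attended variants:

import torch

B, Nu, NL1, Lu, Lr, d = 2, 4, 3, 16, 16, 32      # NL1 = NL + 1 stacked layers
uttrs_stack = torch.randn(B * Nu * NL1, Lu, d)   # stand-in for the real stacks
resp_stack = torch.randn(B * Nu * NL1, Lr, d)
M = torch.tanh(torch.matmul(uttrs_stack, resp_stack.transpose(1, 2)))
M = M.view(B, Nu, NL1, Lu, Lr)                   # per-layer similarity cube
cube = torch.cat([M, M], dim=2).transpose(1, 2)  # self + cross channels
print(cube.shape)                                # [B, 2*NL1, Nu, Lu, Lr] = [2, 6, 4, 16, 16]
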
Example #4
    def forward(self, inputs):
        """
        Nu is the number of utterance turns, Lu the length of one utterance,
        Lr the length of one response.
        """
        # [B, Nu, Lu]
        utterances = inputs['utterances']
        # [B, Lr]
        response = inputs['response']
        bsz = utterances.size(0)
        # Tensor of turn indices [1, Nu]
        turns_num = inputs['turns']
        # Mask the padding in the utterances [B*Nu, 1, Lu]
        uttrs_mask = get_pad_mask(utterances.view(-1, self.uttr_len),
                                  self._params['padding_idx'])
        # Mask the padding in the response [B, 1, Lr]
        resp_mask = get_pad_mask(response, self._params['padding_idx'])
        # [B, Nu, Lu, embed_output_dim]
        word_embedding = self.word_emb(utterances)
        # [B, Nu, Lu, d_model]
        U_emb = self.word_proj(word_embedding)
        # [B, Lr, embed_output_dim]
        resp_embedding = self.word_emb(response)
        # [B, Lr, d_model]
        R_emb = self.word_proj(resp_embedding)
        # Optionally add the turn embedding
        if self._params['is_turn_emb']:
            ## [1, Nu, 1, d_model]
            turns_embedding = self.turn_emb(turns_num).unsqueeze(dim=-2)
            ## [B, Nu, Lu, d_model]
            U_emb = U_emb + turns_embedding
        if self._params['is_position_emb']:
            # [B*Nu, Lu, d_model]
            U_emb = self.position_emb(
                U_emb.view(bsz * self.turns, self.uttr_len, -1))
            # [B, Lr, d_model]
            R_emb = self.position_emb(R_emb)

        U_emb = U_emb * (uttrs_mask.squeeze(dim=-2).unsqueeze(dim=-1))
        R_emb = R_emb * (resp_mask.squeeze(dim=-2).unsqueeze(dim=-1))

        # 1. Run the context selector to pick out the useful context.
        ## [B, Nu, Lu, d_model]
        multi_context = self.context_selector(
            U_emb.view(-1, self.turns, self.uttr_len, self._params['d_model']),
            uttrs_mask, self._params['hop_k'])
        ## [B*Nu, Lu, d_model]
        multi_context = multi_context.view(-1, self.uttr_len,
                                           self._params['d_model'])
        ## [B*Nu, Lr, d_model]
        R_emb = R_emb.unsqueeze(dim=1).repeat(1, self.turns, 1,
                                              1).view(-1, self.resp_len,
                                                      self._params['d_model'])
        resp_mask = resp_mask.unsqueeze(dim=1).repeat(1, self.turns, 1,
                                                      1).view(
                                                          -1, 1, self.resp_len)

        # 2. Extract features with several convolutional layers.
        ## [B*Nu, linear_out]
        V = self.UR_matching(multi_context,
                             R_emb,
                             U_mask=uttrs_mask,
                             R_mask=resp_mask)
        ## [B, Nu, linear_out]
        V = V.view(bsz, self.turns, -1)

        # 3. Run the GRU and produce the final classification.
        ## [B, Nu, direction*hidden_size] and [direction, B, hidden_size]
        outputs, h = self.gru(V)
        ## [B, direction*hidden_size]
        output = h.transpose(0, 1).contiguous().view(bsz, -1)
        if self._params['maxpool_output']:
            ## [B, direction*hidden_size]
            maxpool_output = outputs.max(dim=1)[0]
            ## [B, 2*direction*hidden_size]
            output = torch.cat([output, maxpool_output], dim=-1)
        output = self.dropout(output)
        # [B, num_classes]
        output = self.output(output)
        return output
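
A hedged sketch of how these forward passes are driven; the input keys match the code above, but `model` and every size are invented:

import torch

B, Nu, Lu, Lr, vocab = 2, 5, 20, 20, 30000               # assumed sizes
inputs = {
    'utterances': torch.randint(1, vocab, (B, Nu, Lu)),  # [B, Nu, Lu]
    'response': torch.randint(1, vocab, (B, Lr)),        # [B, Lr]
    'turns': torch.arange(Nu).unsqueeze(0),              # [1, Nu]
}
logits = model(inputs)  # `model`: hypothetical instance of one of these modules
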
Example #5
    def forward(self, inputs):
        # [B, Nu, Lu]
        utterances = inputs['utterances']
        # [B, Lr]
        response = inputs['response']
        bsz = utterances.size(0)
        # Attention mask for ALBERT over the utterances [B*Nu, Lu]
        uttrs_mask = get_pad_mask(utterances.view(-1, self.uttr_len),
                                  self._params['padding_idx']).squeeze(dim=-2)
        # Mask the padding in the response [B, Lr]
        resp_mask = get_pad_mask(response,
                                 self._params['padding_idx']).squeeze(dim=-2)
        utterances = utterances.view(bsz * self.turns, self.uttr_len)
        ## [B*Nu, Lu, d_model]
        U_emb = self.albert(input_ids=utterances, attention_mask=uttrs_mask)[0]
        ## [B, Lr, d_model]
        R_emb = self.albert(input_ids=response, attention_mask=resp_mask)[0]

        ## Zero out the padding positions
        U_emb = U_emb * (uttrs_mask.unsqueeze(dim=-1))
        R_emb = R_emb * (resp_mask.unsqueeze(dim=-1))
        uttrs_mask = uttrs_mask.unsqueeze(dim=-2)
        resp_mask = resp_mask.unsqueeze(dim=-2)

        # 1. Run the context selector to pick out the useful context.
        ## [B, Nu, Lu, d_model]
        multi_context = self.context_selector(
            U_emb.view(-1, self.turns, self.uttr_len, self._params['d_model']),
            uttrs_mask, self._params['hop_k'])
        ## [B*Nu, Lu, d_model]
        multi_context = multi_context.view(-1, self.uttr_len,
                                           self._params['d_model'])

        ## [B*Nu, Lr, d_model]
        R_emb = R_emb.unsqueeze(dim=1).repeat(1, self.turns, 1,
                                              1).view(-1, self.resp_len,
                                                      self._params['d_model'])
        resp_mask = resp_mask.unsqueeze(dim=1).repeat(1, self.turns, 1,
                                                      1).view(
                                                          -1, 1, self.resp_len)
        # 2. Extract features with several convolutional layers.
        ## [B*Nu, linear_out]
        V = self.UR_matching(multi_context,
                             R_emb,
                             U_mask=uttrs_mask,
                             R_mask=resp_mask)
        V = V.view(bsz, self.turns, -1)

        # 3. Run the GRU and produce the final classification.
        ## [B, Nu, direction*hidden_size] and [direction, B, hidden_size]
        outputs, h = self.gru(V)
        ## [B, direction*hidden_size]
        output = h.transpose(0, 1).contiguous().view(bsz, -1)
        if self._params['maxpool_output']:
            ## [B, direction*hidden_size]
            maxpool_output = outputs.max(dim=1)[0]
            ## [B, 2*direction*hidden_size]
            output = torch.cat([output, maxpool_output], dim=-1)
        output = self.dropout(output)
        # [B, num_classes]
        output = self.output(output)
        return output
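
The ALBERT encoding step above reduces to a standard Hugging Face Transformers call; a minimal self-contained sketch follows (the checkpoint name and sizes are assumptions):

import torch
from transformers import AlbertModel

albert = AlbertModel.from_pretrained('albert-base-v2')  # assumed checkpoint
input_ids = torch.randint(5, 30000, (4, 20))            # [B*Nu, Lu], fake ids
attention_mask = (input_ids != 0).long()                # 1 = real token, 0 = pad
hidden = albert(input_ids=input_ids, attention_mask=attention_mask)[0]
print(hidden.shape)                                     # [4, 20, 768] -> [B*Nu, Lu, d_model]
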
Example #6
    def forward(self, inputs):
        # Fetch the input utterances and response.
        ## [B, Nu, Lu]
        utterances = inputs[constants.UTTRS]
        ## [B, Lr]
        response = inputs[constants.RESP]

        bsz = utterances.size(0)
        ## Reshape the utterances to [B*Nu, Lu]
        utterances = utterances.view(bsz * self.turns, self.uttr_len)
        ## For uru-format data, fetch the UR position tags.
        if self.data_type == "uru":
            ## [B, Nu, Lu]
            ur_pos = inputs[constants.UR_POS]
            ur_pos = ur_pos.view(bsz * self.turns, self.uttr_len)

        # Build masks for the utterances and the response.
        ## [B*Nu, 1, Lu]
        uttrs_mask = get_pad_mask(utterances, self.params['padding_idx'])
        uttrs_albert_mask = uttrs_mask.squeeze(dim=-2)  ## [B*Nu, Lu]
        ## [B, 1, Lr]
        resp_mask = get_pad_mask(response, self.params['padding_idx'])
        resp_albert_mask = resp_mask.squeeze(dim=-2)  ## [B, Lr]

        ## Encode the utterances and the response with ALBERT.
        U_emb = self.albert(
            input_ids=utterances,
            attention_mask=uttrs_albert_mask)[0]  ## [B*Nu, Lu, d_model]
        R_emb = self.albert(
            input_ids=response,
            attention_mask=resp_albert_mask)[0]  ## [B, Lr, d_model]

        if self.params['is_ur_embed'] and self.data_type == "uru":
            ur_emb = self.ur_embed(ur_pos)
            U_emb = U_emb + ur_emb

        ## Mask the padding positions of the encoder outputs.
        uttrs_embed_mask = uttrs_albert_mask.unsqueeze(dim=-1)
        resp_embed_mask = resp_albert_mask.unsqueeze(dim=-1)
        U_emb = U_emb * uttrs_embed_mask
        R_emb = R_emb * resp_embed_mask

        # ------------------ Attentive HR Encoder ----------------------
        ## [B*Nu, Lu, NL, 2*hid]
        U_stack = self.stack_brnn(U_emb, uttrs_embed_mask).squeeze(dim=-2)
        ## [B, Lr, NL, 2*hid]
        R_stack = self.stack_brnn(R_emb, resp_embed_mask).squeeze(dim=-2)
        # ## [NL, 1]
        # wm = F.softmax(self.w_m, dim=0).unsqueeze(-1)
        # ## Weighted attention over the utterance layers [B*Nu, Lu, 2*hid]
        # U_attn = torch.matmul(U_stack.transpose(2, 3), wm).squeeze(dim=-1)
        # ## Weighted attention over the response layers [B, Lr, 2*hid]
        # R_attn = torch.matmul(R_stack.transpose(2, 3), wm).squeeze(dim=-1)

        U_attn = U_stack
        R_attn = R_stack

        # ------------------ Matching Layer -------------------------

        U_attn_reshape = U_attn.view(bsz, self.turns * self.uttr_len, -1)
        uttrs_mask_reshape = uttrs_mask.squeeze(dim=1).view(
            bsz, self.turns * self.uttr_len).unsqueeze(dim=-2)
        # ## [B, Nu*Lu, 2*hid]
        # U_R_attn, *_ = self.cross_attn(U_attn_reshape, R_attn, R_attn, resp_mask)
        # ## [B*Nu, Lu, 2*hid]
        # U_R_attn = U_R_attn.view(bsz*self.turns, self.uttr_len, -1)
        # ## [B, Lr, 2*hid]
        # R_U_attn, *_ = self.cross_attn(R_attn, U_attn_reshape, U_attn_reshape, uttrs_mask.squeeze(dim=1).view(bsz, self.turns*self.uttr_len).unsqueeze(dim=-2))
        ## [B, Nu*Lu, 2*hid] and [B, Lr, 2*hid]
        U_R_attn, R_U_attn = self.attention(U_attn_reshape,
                                            R_attn,
                                            U_mask=uttrs_mask_reshape,
                                            R_mask=resp_mask)
        U_R_attn = U_R_attn.view(bsz * self.turns, self.uttr_len, -1)

        ## [B*Nu, Lu, 8*hid]
        C_mat = self.get_matching_tensor(U_attn, U_R_attn)
        ## [B, Lr, 8*hid]
        R_mat = self.get_matching_tensor(R_attn, R_U_attn)

        # ----------------- Aggregation Layer -----------------------

        ## [B*Nu, Lu, 2*hid] and [2, B*Nu, hid]
        C_out, C_state = self.sent_gru(C_mat)
        ## [B, Lr, 2*hid] and [2, B, hid]
        R_out, R_state = self.sent_gru(R_mat)
        ## Max pooling and mean pooling
        ### Each of the three below is [B*Nu, 2*hid]
        C_mean = torch.mean(C_out, dim=1)
        C_max = torch.max(C_out, dim=1)[0]
        C_state = C_state.transpose(0,
                                    1).contiguous().view(bsz * self.turns, -1)
        ## [B, Nu, 6*hid]
        C_out = torch.cat([C_mean, C_max, C_state],
                          dim=-1).view(bsz, self.turns, -1)
        ## [B, Nu, 2*hid] and [2, B, hid]
        C_out, C_state = self.uttr_gru(C_out)
        ## Aggregate the utterances and the response
        C_mean = torch.mean(C_out, dim=1)  # [B, 2*hid]
        C_max = torch.max(C_out, dim=1)[0]  # [B, 2*hid]
        C_state = C_state.transpose(0, 1).contiguous().view(bsz, -1)

        R_mean = torch.mean(R_out, dim=1)  # [B, 2*hid]
        R_max = torch.max(R_out, dim=1)[0]  # [B, 2*hid]
        R_state = R_state.transpose(0, 1).contiguous().view(bsz, -1)

        ## [B, 12*hid]
        M_agr = torch.cat([C_mean, C_max, C_state, R_mean, R_max, R_state],
                          dim=-1)

        # ------------------ Output Layer ---------------------
        output = self.dropout(M_agr)
        output = self.mlps(output)
        output = self.output(output)
        return output
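
The aggregation pattern used by both of these models (bi-GRU, then mean pooling, max pooling, and the flattened final state, concatenated) can be checked in isolation; the sizes below are invented:

import torch
import torch.nn as nn

B, L, d_in, hid = 2, 10, 16, 8
gru = nn.GRU(d_in, hid, batch_first=True, bidirectional=True)
out, state = gru(torch.randn(B, L, d_in))  # out: [B, L, 2*hid], state: [2, B, hid]
mean = out.mean(dim=1)                     # [B, 2*hid]
mx = out.max(dim=1)[0]                     # [B, 2*hid]
last = state.transpose(0, 1).contiguous().view(B, -1)  # [B, 2*hid]
agg = torch.cat([mean, mx, last], dim=-1)  # [B, 6*hid]
print(agg.shape)                           # torch.Size([2, 48])
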