Example 1
    def embed_context(self, context_input_ids, context_segment_ids,
                      context_input_masks):
        batch_size, seq_length = context_input_ids.shape  # context ids are [bs, seq_length]

        ## poly context encoder
        if isinstance(self.bert, DistilBertModel):
            state_vecs = self.bert(
                context_input_ids,
                context_input_masks)[-1]  # [bs, length, dim]
        else:
            state_vecs = self.bert(context_input_ids, context_input_masks,
                                   context_segment_ids)[0]  # [bs, length, dim]
        # context codes use ids 1..poly_m (id 0 is reserved for the response code)
        poly_code_ids = torch.arange(self.poly_m,
                                     dtype=torch.long,
                                     device=context_input_ids.device)
        poly_code_ids += 1
        poly_code_ids = poly_code_ids.unsqueeze(0).expand(
            batch_size, self.poly_m)
        poly_codes = self.poly_code_embeddings(poly_code_ids)
        context_vecs = dot_attention(poly_codes, state_vecs, state_vecs,
                                     context_input_masks, self.dropout)

        ## normalize first here; this effectively yields context_vec and response_vec
        context_vecs = self.context_fc(self.dropout(context_vecs))
        context_vecs = F.normalize(context_vecs, 2, -1)  # [bs, m, dim]

        ## the final poly aggregation against response vectors happens in forward()
        return context_vecs
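
The helper dot_attention is not defined in these examples. A minimal sketch of a compatible implementation (standard masked dot-product attention; the masking constant and the dropout placement are assumptions, not the original code):

import torch
import torch.nn.functional as F

def dot_attention(q, k, v, v_mask=None, dropout=None):
    # q: [*, q_len, dim]; k, v: [*, kv_len, dim]
    # v_mask: [*, kv_len] with 1 for real tokens, 0 for padding
    attention_weights = torch.matmul(q, k.transpose(-1, -2))  # [*, q_len, kv_len]
    if v_mask is not None:
        # push padded positions toward -inf before the softmax (assumed constant)
        attention_weights += (1.0 - v_mask.unsqueeze(1)) * -1e5
    attention_weights = F.softmax(attention_weights, dim=-1)
    if dropout is not None:
        attention_weights = dropout(attention_weights)
    return torch.matmul(attention_weights, v)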
Example 2
    def embed_response(self, responses_input_ids, responses_segment_ids,
                       responses_input_masks):
        batch_size, res_cnt, seq_length = responses_input_ids.shape

        ## response encoder
        responses_input_ids = responses_input_ids.view(-1, seq_length)
        responses_input_masks = responses_input_masks.view(-1, seq_length)
        responses_segment_ids = responses_segment_ids.view(-1, seq_length)
        if isinstance(self.bert, DistilBertModel):
            state_vecs = self.bert(
                responses_input_ids,
                responses_input_masks)[-1]  # [bs*res_cnt, length, dim]
        else:
            state_vecs = self.bert(
                responses_input_ids, responses_input_masks,
                responses_segment_ids)[0]  # [bs*res_cnt, length, dim]
        # every response uses the single poly code id 0
        poly_code_ids = torch.zeros(batch_size * res_cnt,
                                    1,
                                    dtype=torch.long,
                                    device=responses_input_ids.device)
        poly_codes = self.poly_code_embeddings(poly_code_ids)
        responses_vec = dot_attention(poly_codes, state_vecs, state_vecs,
                                      responses_input_masks, self.dropout)
        responses_vec = responses_vec.view(batch_size, res_cnt, -1)

        responses_vec = self.response_fc(self.dropout(responses_vec))
        responses_vec = F.normalize(responses_vec, 2, -1)

        return responses_vec
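
All three snippets rely on attributes created in the class constructor, which is not shown. A hypothetical __init__ consistent with how those attributes are used (names, hyperparameters, and the dropout rate are illustrative assumptions):

import torch.nn as nn

class PolyEncoder(nn.Module):
    def __init__(self, bert, poly_m=16):
        super().__init__()
        self.bert = bert                      # BertModel or DistilBertModel
        self.poly_m = poly_m                  # number of context codes
        self.vec_dim = bert.config.hidden_size
        # poly_m context codes (ids 1..poly_m) plus the response code (id 0)
        self.poly_code_embeddings = nn.Embedding(poly_m + 1, self.vec_dim)
        self.context_fc = nn.Linear(self.vec_dim, self.vec_dim)
        self.response_fc = nn.Linear(self.vec_dim, self.vec_dim)
        self.dropout = nn.Dropout(0.1)        # assumed dropout rate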
Example 3
    def forward(self,
                context_input_ids,
                context_segment_ids,
                context_input_masks,
                responses_input_ids,
                responses_segment_ids,
                responses_input_masks,
                labels=None):
        ## during training, keep only the first response of each example (the positive one, label == 1)
        if labels is not None:
            responses_input_ids = responses_input_ids[:, 0, :].unsqueeze(1)
            responses_segment_ids = responses_segment_ids[:, 0, :].unsqueeze(1)
            responses_input_masks = responses_input_masks[:, 0, :].unsqueeze(1)
        batch_size, res_cnt, seq_length = responses_input_ids.shape

        ## poly context encoder
        if isinstance(self.bert, DistilBertModel):
            state_vecs = self.bert(
                context_input_ids,
                context_input_masks)[-1]  # [bs, length, dim]
        else:
            state_vecs = self.bert(context_input_ids, context_input_masks,
                                   context_segment_ids)[0]  # [bs, length, dim]
        # context codes use ids 1..poly_m, as in embed_context
        poly_code_ids = torch.arange(self.poly_m,
                                     dtype=torch.long,
                                     device=context_input_ids.device)
        poly_code_ids += 1
        poly_code_ids = poly_code_ids.unsqueeze(0).expand(
            batch_size, self.poly_m)
        poly_codes = self.poly_code_embeddings(poly_code_ids)
        context_vecs = dot_attention(poly_codes, state_vecs, state_vecs,
                                     context_input_masks, self.dropout)

        ## response encoder
        responses_input_ids = responses_input_ids.view(-1, seq_length)
        responses_input_masks = responses_input_masks.view(-1, seq_length)
        responses_segment_ids = responses_segment_ids.view(-1, seq_length)
        if isinstance(self.bert, DistilBertModel):
            state_vecs = self.bert(
                responses_input_ids,
                responses_input_masks)[-1]  # [bs*res_cnt, length, dim]
        else:
            state_vecs = self.bert(
                responses_input_ids, responses_input_masks,
                responses_segment_ids)[0]  # [bs*res_cnt, length, dim]
        # every response uses the single poly code id 0, as in embed_response
        poly_code_ids = torch.zeros(batch_size * res_cnt,
                                    1,
                                    dtype=torch.long,
                                    device=context_input_ids.device)
        poly_codes = self.poly_code_embeddings(poly_code_ids)
        responses_vec = dot_attention(poly_codes, state_vecs, state_vecs,
                                      responses_input_masks, self.dropout)
        responses_vec = responses_vec.view(batch_size, res_cnt, -1)

        ## normalize first here; this effectively yields context_vec and response_vec
        context_vecs = self.context_fc(self.dropout(context_vecs))
        context_vecs = F.normalize(context_vecs, 2, -1)  # [bs, m, dim]
        responses_vec = self.response_fc(self.dropout(responses_vec))
        responses_vec = F.normalize(responses_vec, 2, -1)

        ## poly final context vector aggregation
        if labels is not None:
            # during training, score each context against every in-batch response
            responses_vec = responses_vec.view(1, batch_size, -1).expand(
                batch_size, batch_size, self.vec_dim)
        final_context_vec = dot_attention(responses_vec, context_vecs,
                                          context_vecs, None, self.dropout)
        final_context_vec = F.normalize(
            final_context_vec, 2,
            -1)  # [bs, res_cnt, dim], res_cnt==bs when training

        dot_product = torch.sum(final_context_vec * responses_vec,
                                -1)  # [bs, res_cnt], res_cnt==bs when training
        if labels is not None:
            # identity mask keeps only the diagonal (matching) context-response scores
            mask = torch.eye(context_input_ids.size(0)).to(
                context_input_ids.device)
            # scaled (x5) log-softmax over in-batch negatives; the sum picks the diagonal term
            loss = F.log_softmax(dot_product * 5, dim=-1) * mask
            loss = (-loss.sum(dim=1)).mean()

            return loss
        else:
            # map the dot product of unit vectors (cosine in [-1, 1]) into [0, 1]
            cos_similarity = (dot_product + 1) / 2
            return cos_similarity
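
Since the identity mask keeps only the diagonal (matching) scores, the training loss above is just a temperature-scaled cross-entropy over in-batch negatives. An equivalent restatement, assuming dot_product is the [bs, bs] score matrix from forward():

import torch
import torch.nn.functional as F

# targets[i] == i: the i-th context matches the i-th in-batch response
targets = torch.arange(dot_product.size(0), device=dot_product.device)
loss = F.cross_entropy(dot_product * 5, targets)  # 5 is the fixed scale used above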