Code Example #1
    def forward(self, input, encoder_state, incr_state=None):
        attention_mask = None
        if incr_state is None:
            # first step
            if (not self.add_start_token and input.size(1) == 1
                    and int(input[0][0]) == self.start_idx):
                # generating: ignore the start token
                model_input = encoder_state
            else:
                # forced decoding: concatenate the context
                # with the labels
                model_input, _ = concat_without_padding(encoder_state,
                                                        input,
                                                        use_cuda=self.use_cuda,
                                                        null_idx=self.null_idx)
                attention_mask = model_input != self.null_idx
        else:
            # generation: get the last token input
            model_input = input[:, -1].unsqueeze(1)

        transformer_outputs = self.transformer(model_input,
                                               past=incr_state,
                                               attention_mask=attention_mask)
        hidden_states = transformer_outputs[0]
        new_incr_state = transformer_outputs[1]

        return hidden_states, new_incr_state
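All three examples in this section build their model input with ParlAI's concat_without_padding helper. For reference, the sketch below reconstructs the behavior its call sites imply: left-align each row by dropping the padding between the two sequences (assuming right-padded inputs), and return token-type segments (0 over the first sequence, 1 over the second). This is a minimal reimplementation inferred from the calls above, not necessarily the library's exact code.

import torch

def concat_without_padding(text_idx, cand_idx, use_cuda=False, null_idx=0):
    # Sketch of the ParlAI utility, reconstructed from its call sites;
    # assumes both inputs are right-padded with null_idx.
    assert text_idx.size(0) == cand_idx.size(0)
    text_idx, cand_idx = text_idx.cpu(), cand_idx.cpu()
    cand_len = cand_idx.size(1)
    concat_len = text_idx.size(1) + cand_len
    tokens = text_idx.new_full((text_idx.size(0), concat_len), null_idx)
    # segment ids: 0 for the context part, 1 for the candidate part
    segments = text_idx.new_zeros(text_idx.size(0), concat_len)
    for i in range(tokens.size(0)):
        # copy the non-padding prefix of the context, then append the
        # candidate immediately after it, leaving no gap in between
        non_nuls = int(torch.sum(text_idx[i] != null_idx))
        tokens[i, :non_nuls] = text_idx[i, :non_nuls]
        tokens[i, non_nuls:non_nuls + cand_len] = cand_idx[i]
        segments[i, non_nuls:non_nuls + cand_len] = 1
    if use_cuda:
        tokens, segments = tokens.cuda(), segments.cuda()
    return tokens, segments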
Code Example #2
    def score_candidates(self, batch, cand_vecs, cand_encs=None):
        if cand_encs is not None:
            raise Exception('Candidate pre-computation is impossible on the '
                            'crossencoder')
        num_cands_per_sample = cand_vecs.size(1)
        bsz = cand_vecs.size(0)
        # repeat each context once per candidate and flatten both tensors
        # to (bsz * num_cands, seq_len) so they can be scored pairwise
        text_idx = (batch.text_vec.unsqueeze(1).expand(
            -1, num_cands_per_sample,
            -1).contiguous().view(num_cands_per_sample * bsz, -1))
        cand_idx = cand_vecs.view(num_cands_per_sample * bsz, -1)
        tokens, segments = concat_without_padding(text_idx, cand_idx,
                                                  self.use_cuda, self.NULL_IDX)
        scores = self.model(tokens, segments)
        scores = scores.view(bsz, num_cands_per_sample)
        return scores
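To make the expand/view gymnastics above concrete, here is a standalone toy run with made-up shapes (a batch of 2 contexts, 3 candidates each); the variable names mirror the snippet.

import torch

bsz, num_cands_per_sample, text_len, cand_len = 2, 3, 5, 4
text_vec = torch.randint(1, 100, (bsz, text_len))
cand_vecs = torch.randint(1, 100, (bsz, num_cands_per_sample, cand_len))

# repeat each context once per candidate, then flatten both tensors to
# (bsz * num_cands, seq_len) so row i of one aligns with row i of the other
text_idx = (text_vec.unsqueeze(1)
            .expand(-1, num_cands_per_sample, -1)
            .contiguous()
            .view(num_cands_per_sample * bsz, -1))
cand_idx = cand_vecs.view(num_cands_per_sample * bsz, -1)
print(text_idx.shape, cand_idx.shape)  # torch.Size([6, 5]) torch.Size([6, 4])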
Code Example #3
    def score_candidates(self, batch, cand_vecs, cand_encs=None):
        # concatenate text and candidates (not so easy)
        # unpad and break
        nb_cands = cand_vecs.size()[1]
        size_batch = cand_vecs.size()[0]
        text_vec = batch.text_vec

        tokens_context = (text_vec.unsqueeze(1).expand(
            -1, nb_cands, -1).contiguous().view(nb_cands * size_batch, -1))

        # flatten candidates to (nb_cands * size_batch, -1) to align
        # with the expanded contexts
        tokens_cands = cand_vecs.view(nb_cands * size_batch, -1)
        all_tokens, all_segments = concat_without_padding(
            tokens_context, tokens_cands, self.use_cuda, self.NULL_IDX)
        all_mask = all_tokens != self.NULL_IDX
        all_tokens *= all_mask.long()
        scores = self.model(all_tokens, all_segments, all_mask)
        return scores.view(size_batch, nb_cands)
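Tying the pieces together, here is a toy run of the masking step using the concat_without_padding sketch given after Code Example #1 (hypothetical token ids; NULL_IDX assumed to be 0, the usual BERT [PAD] id). Note that padding inside a candidate keeps segment id 1 but is still masked out of attention.

import torch

NULL_IDX = 0
context = torch.tensor([[101, 7592, 102, 0, 0]])  # one right-padded context
candidate = torch.tensor([[2023, 2003, 102, 0]])  # one right-padded candidate

all_tokens, all_segments = concat_without_padding(
    context, candidate, False, NULL_IDX)
all_mask = all_tokens != NULL_IDX
all_tokens *= all_mask.long()                     # force padding back to 0

print(all_tokens)    # [[ 101, 7592,  102, 2023, 2003,  102,    0,    0,    0]]
print(all_segments)  # [[   0,    0,    0,    1,    1,    1,    1,    0,    0]]
print(all_mask)      # first six positions True, last three False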