def forward(self, input, encoder_state, incr_state=None):
    """
    Run one decoder forward pass.

    :param input: token ids for the current decoding step(s).
    :param encoder_state: token ids of the context fed on the first step.
    :param incr_state: transformer `past` cache from the previous step, or
        None on the first step.
    :return: tuple of (hidden_states, new_incr_state) from the transformer.
    """
    mask = None
    if incr_state is not None:
        # Incremental generation: only the newest token is fed; the cache
        # carries everything before it.
        decoder_input = input[:, -1].unsqueeze(1)
    elif (not self.add_start_token and input.size(1) == 1
            and int(input[0][0]) == self.start_idx):
        # First generation step: the lone start token is dropped and the
        # context alone primes the model.
        decoder_input = encoder_state
    else:
        # Forced decoding: context and labels are concatenated, with an
        # attention mask hiding the padding positions.
        decoder_input, _ = concat_without_padding(
            encoder_state,
            input,
            use_cuda=self.use_cuda,
            null_idx=self.null_idx,
        )
        mask = decoder_input != self.null_idx
    outputs = self.transformer(
        decoder_input, past=incr_state, attention_mask=mask
    )
    return outputs[0], outputs[1]
def score_candidates(self, batch, cand_vecs, cand_encs=None):
    """
    Score every candidate against its batch context with the cross-encoder.

    :param batch: batch object; only ``batch.text_vec`` (bsz x seq_len token
        ids) is read.
    :param cand_vecs: bsz x num_cands x cand_len tensor of candidate token
        ids.
    :param cand_encs: must be None — the cross-encoder encodes context and
        candidate jointly, so candidates cannot be pre-encoded.
    :return: bsz x num_cands tensor of scores.
    :raises RuntimeError: if ``cand_encs`` is not None.
    """
    if cand_encs is not None:
        # RuntimeError (an Exception subclass) keeps any existing
        # `except Exception` handlers working while being specific
        # enough to catch on its own.
        raise RuntimeError('Candidate pre-computation is impossible on the '
                           'crossencoder')
    num_cands_per_sample = cand_vecs.size(1)
    bsz = cand_vecs.size(0)
    # Repeat each context once per candidate, then flatten so the model
    # scores all (bsz * num_cands) context/candidate pairs in one pass.
    text_idx = (
        batch.text_vec.unsqueeze(1)
        .expand(-1, num_cands_per_sample, -1)
        .contiguous()
        .view(num_cands_per_sample * bsz, -1)
    )
    cand_idx = cand_vecs.view(num_cands_per_sample * bsz, -1)
    tokens, segments = concat_without_padding(
        text_idx, cand_idx, self.use_cuda, self.NULL_IDX
    )
    scores = self.model(tokens, segments)
    return scores.view(bsz, num_cands_per_sample)
def score_candidates(self, batch, cand_vecs, cand_encs=None):
    """
    Score candidates by jointly encoding each one with its batch context.

    :param batch: batch object; only ``batch.text_vec`` (bsz x seq_len
        token ids) is read.
    :param cand_vecs: bsz x num_cands x cand_len tensor of candidate token
        ids.
    :param cand_encs: accepted for API parity but not used here.
    :return: bsz x num_cands tensor of scores.
    """
    batch_size = cand_vecs.size()[0]
    n_cands = cand_vecs.size()[1]
    # Tile every context once per candidate, then flatten to
    # (batch_size * n_cands, seq_len) so each row pairs one context with
    # one candidate.
    flat_context = (
        batch.text_vec.unsqueeze(1)
        .expand(-1, n_cands, -1)
        .contiguous()
        .view(n_cands * batch_size, -1)
    )
    flat_cands = cand_vecs.view(n_cands * batch_size, -1)
    joined_tokens, joined_segments = concat_without_padding(
        flat_context, flat_cands, self.use_cuda, self.NULL_IDX
    )
    # Build the attention mask from the padding positions and zero the
    # padded token ids themselves before scoring.
    pad_mask = joined_tokens != self.NULL_IDX
    joined_tokens *= pad_mask.long()
    flat_scores = self.model(joined_tokens, joined_segments, pad_mask)
    return flat_scores.view(batch_size, n_cands)