def _block_ngrams(
    self, ngram_size: int, logprobs: torch.Tensor, source: torch.LongTensor = None
):
    """
    Hard block ngrams from the logprobs, based on the source.

    :param ngram_size:
        The length of ngrams to block. Must be > 0.
    :param logprobs:
        Float or HalfTensor, representing the log-probabilities. This is
        modified in place.
    :param source:
        Source text to grab ngrams from. If None, it uses the current
        hypothesis (i.e. self-blocking).
    """
    for beam_id, hyp in enumerate(self.partial_hyps):
        if len(hyp) < ngram_size - 1:
            continue
        source_ = hyp if source is None else source
        ngrams = self._find_ngrams(source_, ngram_size)
        prefix = hyp[-(ngram_size - 1):]
        for ngram in ngrams:
            if ngram_size == 1 or prefix == list(ngram[:-1]):
                logprobs[beam_id][ngram[-1]] = neginf(logprobs.dtype)
    return logprobs
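# A minimal, hypothetical sketch (not part of the method above) of the same
# n-gram blocking idea on a toy vocabulary: if extending a hypothesis would
# repeat an already-seen trigram, the candidate token's log-probability is
# pushed to a large negative value. The local `neginf` helper here is an
# assumption standing in for the library's helper.
import torch

def neginf(dtype: torch.dtype) -> float:
    # assumed helper: a large negative value that is safe for the given dtype
    return -65504.0 if dtype is torch.float16 else -1e20

def block_repeated_ngrams(hyp, logprobs, ngram_size=3):
    """Ban any token that would complete an n-gram already present in `hyp`."""
    seen = {tuple(hyp[i:i + ngram_size]) for i in range(len(hyp) - ngram_size + 1)}
    prefix = tuple(hyp[-(ngram_size - 1):])
    for ngram in seen:
        if ngram[:-1] == prefix:
            logprobs[ngram[-1]] = neginf(logprobs.dtype)
    return logprobs

hyp = [5, 7, 9, 5, 7]               # trigram (5, 7, 9) was already generated
scores = torch.zeros(12)            # toy vocabulary of 12 tokens
block_repeated_ngrams(hyp, scores)  # token 9 is now effectively banned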
def get_extra_output_from_mask(
    self,
    input: torch.LongTensor,
    encoder_output: torch.Tensor,
    encoder_mask: torch.Tensor,
) -> ExtraOutput:
    """
    Use a trainable mask layer to determine which elements of the input to
    re-attend to.

    :param input:
        vectorized input tokens
    :param encoder_output:
        output encodings of input tokens
    :param encoder_mask:
        mask for input

    :return (enc_out, enc_mask):
        return the extra output to which we will be attending (for all layers).
    """
    weights = self.softmax(
        self.mask_dropout(self.mask_linear(encoder_output)).masked_fill_(
            (encoder_mask == 0)
            .view(*encoder_mask.size(), 1)
            .expand(*encoder_output.size()),
            neginf(encoder_output.dtype),
        ),
        dim=1,
    )
    topk = get_topk(self.opt, input.size(-1))
    topk_inds = weights.sum(-1).topk(topk, dim=-1, sorted=False).indices
    new_input = torch.gather(input, dim=-1, index=topk_inds)
    out2 = super().forward(new_input)
    assert isinstance(out2, tuple)
    return (*out2, weights)  # type: ignore
def forward(self, query_embs, in_mem_embs, out_mem_embs, pad_mask):
    """
    Compute MemNN Hop step.

    :param query_embs:
        (bsz x esz) embedding of queries
    :param in_mem_embs:
        bsz list of (num_mems x esz) embedding of memories for activation
    :param out_mem_embs:
        bsz list of (num_mems x esz) embedding of memories for outputs
    :param pad_mask:
        (bsz x num_mems) optional mask indicating which tokens correspond to
        padding

    :returns:
        (bsz x esz) output state
    """
    # rotate query embeddings
    attn = torch.bmm(query_embs.unsqueeze(1), in_mem_embs).squeeze(1)
    if pad_mask is not None:
        attn[pad_mask] = neginf(attn.dtype)
    probs = self.softmax(attn)
    memory_output = torch.bmm(probs.unsqueeze(1), out_mem_embs).squeeze(1)
    output = memory_output + self.rotate(query_embs)
    return output
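# Hypothetical end-to-end usage sketch of one memory-network hop, independent
# of the module above: the query attends over "input" memory embeddings, and
# the resulting weights read from "output" memory embeddings. Shapes are
# illustrative assumptions (in_mem_embs laid out as bsz x esz x num_mems so the
# bmm works as in the forward pass).
import torch
import torch.nn.functional as F

def memnn_hop(query_embs, in_mem_embs, out_mem_embs, pad_mask=None):
    attn = torch.bmm(query_embs.unsqueeze(1), in_mem_embs).squeeze(1)  # (bsz, num_mems)
    if pad_mask is not None:
        attn = attn.masked_fill(pad_mask, -1e20)  # padded memories never attended
    probs = F.softmax(attn, dim=-1)
    return torch.bmm(probs.unsqueeze(1), out_mem_embs).squeeze(1)      # (bsz, esz)

bsz, esz, num_mems = 2, 8, 5
q = torch.randn(bsz, esz)
in_mems = torch.randn(bsz, esz, num_mems)
out_mems = torch.randn(bsz, num_mems, esz)
mask = torch.zeros(bsz, num_mems, dtype=torch.bool)
mask[:, -1] = True  # last memory slot is padding
print(memnn_hop(q, in_mems, out_mems, mask).shape)  # torch.Size([2, 8])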
def forward(self, token_ids, segment_ids, attention_mask):
    """
    Forward pass.
    """
    output_bert, output_pooler = self.bert_model(
        token_ids, segment_ids, attention_mask
    )
    # output_bert is a list of 12 (for bert base) layers.
    layer_of_interest = output_bert[self.layer_pulled]
    dtype = next(self.parameters()).dtype
    if self.add_transformer_layer:
        # Follow up by yet another transformer layer
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = (~extended_attention_mask).to(dtype) * neginf(
            dtype
        )
        embedding_layer = self.additional_transformer_layer(
            layer_of_interest, extended_attention_mask
        )
    else:
        embedding_layer = layer_of_interest

    if self.aggregation == "mean":
        # consider the average of all the output except CLS.
        # obviously ignores masked elements
        outputs_of_interest = embedding_layer[:, 1:, :]
        mask = attention_mask[:, 1:].type_as(embedding_layer).unsqueeze(2)
        summed_embeddings = torch.sum(outputs_of_interest * mask, dim=1)
        nb_elems = torch.sum(attention_mask[:, 1:].type(dtype), dim=1).unsqueeze(1)
        embeddings = summed_embeddings / nb_elems
    elif self.aggregation == "max":
        # consider the max of all the output except CLS
        outputs_of_interest = embedding_layer[:, 1:, :]
        mask = (~attention_mask[:, 1:]).type(dtype).unsqueeze(2) * neginf(dtype)
        embeddings, _ = torch.max(outputs_of_interest + mask, dim=1)
    else:
        # easiest, we consider the output of "CLS" as the embedding
        embeddings = embedding_layer[:, 0, :]

    # We need this in case of dimensionality reduction
    result = self.additional_linear_layer(embeddings)

    # Sort of hack to make it work with distributed: this way the pooler layer
    # is used for grad computation, even though it does not change anything...
    # in practice, it just adds a very (768*768) x (768*batchsize) matmul
    result += 0 * torch.sum(output_pooler)
    return result
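# A small, self-contained sketch (assumed shapes, not the class above) of the
# "mean" and "max" aggregation strategies: padded positions are excluded from
# the average, and pushed toward -inf before a max-pool, mirroring the forward
# pass above.
import torch

hidden = torch.randn(2, 6, 8)                      # (bsz, seq_len, hidden)
mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                     [1, 1, 1, 1, 1, 0]]).bool()   # 1 = real token, 0 = pad

# masked mean: zero out padding, divide by the number of real tokens
mean_pooled = (hidden * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True)

# masked max: add a huge negative offset at padded positions so they never win
max_pooled = (hidden + (~mask).unsqueeze(-1) * -1e20).max(dim=1).values

print(mean_pooled.shape, max_pooled.shape)         # both torch.Size([2, 8])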
def output(self, tensor):
    """
    Compute output logits.
    """
    # project back to vocabulary
    output = F.linear(tensor, self.embeddings.weight)
    # compatibility with fairseq: fairseq sometimes reuses BOS tokens and
    # we need to force their probability of generation to be 0.
    output[:, :, self.start_idx] = neginf(output.dtype)
    return output
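# Minimal sketch of the same idea with free-standing tensors: decoder states
# are projected back to vocabulary logits by reusing the embedding matrix
# (weight tying), and the BOS column is forced to a near -inf value so it can
# never be generated. Shapes and the BOS index are illustrative assumptions.
import torch
import torch.nn.functional as F

vocab, dim, bos_idx = 100, 16, 1
embeddings = torch.nn.Embedding(vocab, dim)

decoder_out = torch.randn(2, 5, dim)               # (bsz, seq_len, dim)
logits = F.linear(decoder_out, embeddings.weight)  # (bsz, seq_len, vocab)
logits[:, :, bos_idx] = -1e20                      # BOS can never win the argmax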
def output_choose_knowledge(self, out_tokens):
    # softmax the output against the knowledge: can the correct knowledge be selected?
    # encode the context, pretty basic
    # N: batch size, K: number of knowledge sentences, T: time, D: embedding size, Tk:
    context_encoded, context_mask = self.transformer(out_tokens)

    # make all the knowledge into a 2D matrix to encode
    N, K, Tk = self.know_tokens.size()
    know_encoded, know_mask = self.transformer(self.know_tokens.reshape(-1, Tk))

    # compute our sentence embeddings for context and knowledge
    context_use = universal_sentence_embedding(context_encoded, context_mask)
    know_use = universal_sentence_embedding(know_encoded, know_mask)

    # remash it back into the shape we need
    know_use = know_use.reshape(
        N, self.know_tokens.size(1), self.embed_dim
    ) / np.sqrt(self.embed_dim)
    context_use /= np.sqrt(self.embed_dim)

    ck_attn = th.bmm(know_use, context_use.unsqueeze(-1)).squeeze(-1)
    # fill with near -inf
    # ~ is bitwise invert: -2^(N-1) to 2^(N-1)-1
    ck_attn.masked_fill_(~self.ck_mask, neginf(context_encoded.dtype))

    # pick the true chosen sentence. remember that TransformerEncoder outputs
    # (batch, time, embed)
    # but because know_encoded is a flattened, it's really
    # (N * K, T, D)
    # We need to compute the offsets of the chosen_sentences
    cs_encoded = None
    softmax_cs_weight = th.nn.functional.softmax(
        (ck_attn * self.knowledge_lamda), dim=1
    )
    """
    # cs_id is 0; softmax_cs_weight is (B, knowledge)
    true_ids_weight = th.zeros(softmax_cs_weight.shape,
                               device=softmax_cs_weight.device,
                               dtype=softmax_cs_weight.dtype)
    for temp in true_ids_weight:
        temp[0] = 1
    loss = softmax_cs_weight - true_ids_weight
    loss = loss * loss
    loss[loss == 0] = 0.000001
    loss = th.sqrt(loss)
    loss = th.sum(loss) / N
    # print(loss)
    self.know_tokens = None
    self.ck_mask = None
    self.cs_ids = None
    self.use_cs_ids = None
    # also return the knowledge selection mask for the loss
    """
    return softmax_cs_weight
def forward(self, x, mask):
    x = self.linear(x)
    x = self.act(x)
    attn = self.attn_wei(x).squeeze(-1)
    attn.masked_fill_(~mask, neginf(x.dtype))
    attn = self.softmax(attn)
    x = th.einsum('btd,bt->bd', x, attn)
    x = self.final(x)
    return x
def modify_logprobs(self, logprobs: torch.Tensor) -> torch.Tensor:
    """
    Modify logprobs in PACER.

    The way it works:

    1. With frequency r, select a token x_i+1 to re-rank.
    2. Generate word probabilities for token x_i+1.
    3. Examine top k words {x_j | score(x_j) in top_k(P(x_i+1 | x_0,...,x_i))};
       use classifier to predict P(a|x1, ..., x_i, x_j)
    4. Rescore top k words via multiplication, re-normalize, and advance the
       generation.

    :param logprobs:
        initial token probabilities

    :return modified:
        return the modified log probabilities according to PACER
    """
    if random.random() > self.frequency:
        return logprobs
    vals, inds = logprobs.topk(self.n_toks, dim=-1, sorted=False)
    new_probs = logprobs.clone().fill_(neginf(logprobs.dtype))
    # Construct partial hypotheses for each beam for each top K tokens
    batch_hyps = [
        h
        for i in range(len(self.partial_hyps))
        for h in [
            self.agent._v2t(self.partial_hyps[i][1:] + [ind]) for ind in inds[i]
        ]
    ]
    # Classify all beam outputs
    predictor_outputs = self.classifier.batch_classify(
        [self.context_str] * self.n_toks * logprobs.size(0), batch_hyps
    )
    # Extract RPA scores
    log_predictor_scores = (
        torch.stack(
            [
                F.log_softmax(pred['sorted_scores'].float(), dim=0)[
                    int(pred['text'] == self.character) - 1
                ]
                for pred in predictor_outputs
            ]
        )
        .to(vals.device)
        .view(vals.size())
    )
    # "Multiply" Probabilities (in log space...)
    scores = vals + log_predictor_scores
    for i in range(new_probs.size(0)):
        new_probs[i, inds[i]] = scores[i]
    return F.log_softmax(new_probs, dim=-1, dtype=torch.float32)  # type: ignore
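# A stripped-down, hypothetical sketch of the rescoring step only: take the
# top-k token log-probs, add an external per-candidate score in log space
# (here just random numbers), scatter the results back over a -inf background,
# and renormalize with log_softmax. The real method gets its extra scores from
# a classifier; everything here is an illustrative assumption.
import torch
import torch.nn.functional as F

def rescore_topk(logprobs, extra_logscores_fn, k=5):
    vals, inds = logprobs.topk(k, dim=-1)            # (bsz, k)
    extra = extra_logscores_fn(inds)                 # (bsz, k) external log scores
    rescored = torch.full_like(logprobs, -1e20)      # every other token banned
    rescored.scatter_(-1, inds, vals + extra)        # "multiply" in log space
    return F.log_softmax(rescored, dim=-1)           # renormalize

logprobs = F.log_softmax(torch.randn(2, 50), dim=-1)
out = rescore_topk(logprobs, lambda inds: torch.randn(inds.shape).log_softmax(-1))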
def forward(self, x, mask):
    # import ipdb; ipdb.set_trace()
    # x: N x T x D
    N, T, D = x.shape
    x = self.linear(x).view(N, T, self.out, D)
    x = self.act(x)
    attn = self.attn_wei(x).squeeze(-1)
    attn.masked_fill_(~mask[:, :, None], neginf(x.dtype))
    attn = self.softmax(attn)
    x = th.einsum('btod,bto->bod', x, attn)
    x = self.proj(x)
    x = th.einsum('bod,vd->bov', x, self.embeddings)
    return x
def forward(
    self,
    xs: torch.Tensor,
    ys: torch.Tensor,
    mask_ys: Optional[torch.Tensor] = None,
    values: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Compute attention.

    Attend over ys with query xs to obtain weights, then apply weights to
    values (ys if values is None).

    Args:
        xs: B x query_len x dim (queries)
        ys: B x key_len x dim (keys)
        mask_ys: B x key_len (mask)
        values: B x value_len x dim (values); if None, default to ys
    """
    bsz = xs.size(0)
    y_len = ys.size(1)
    x_len = xs.size(1)
    if self.attn == 'cosine':
        l1 = self.cosine(xs, ys).unsqueeze(self.dim - 1)
    else:
        l1 = torch.bmm(xs, ys.transpose(1, 2))
        if self.attn == 'sqrt':
            d_k = ys.size(-1)
            l1 = l1 / math.sqrt(d_k)
    if mask_ys is not None:
        attn_mask = (mask_ys == 0).view(bsz, 1, y_len)
        attn_mask = attn_mask.repeat(1, x_len, 1)
        l1.masked_fill_(attn_mask, neginf(l1.dtype))
    l2 = F.softmax(l1, dim=self.dim, dtype=torch.float).type_as(l1)
    if values is None:
        values = ys
    lhs_emb = torch.bmm(l2, values)

    # add back the query
    if self.residual:
        lhs_emb = lhs_emb.add(xs)

    if self.get_weights:
        return lhs_emb.squeeze(self.dim - 1), l2
    else:
        return lhs_emb.squeeze(self.dim - 1)
def forward(self, src_tokens, know_tokens, ck_mask, cs_ids, use_cs_ids):
    # encode the context, pretty basic
    context_encoded, context_mask = self.transformer(src_tokens)

    # make all the knowledge into a 2D matrix to encode
    N, K, Tk = know_tokens.size()
    know_flat = know_tokens.reshape(-1, Tk)
    know_encoded, know_mask = self.transformer(know_flat)

    # compute our sentence embeddings for context and knowledge
    context_use = universal_sentence_embedding(context_encoded, context_mask)
    know_use = universal_sentence_embedding(know_encoded, know_mask)

    # remash it back into the shape we need
    know_use = know_use.reshape(N, know_tokens.size(1), self.embed_dim)
    context_use /= np.sqrt(self.embed_dim)
    know_use /= np.sqrt(self.embed_dim)

    ck_attn = th.bmm(know_use, context_use.unsqueeze(-1)).squeeze(-1)
    # fill with near -inf
    ck_attn.masked_fill_(~ck_mask, neginf(context_encoded.dtype))

    if not use_cs_ids:
        # if we're not given the true chosen_sentence (test time), pick our
        # best guess
        _, cs_ids = ck_attn.max(1)

    # pick the true chosen sentence. remember that TransformerEncoder outputs
    # (batch, time, embed)
    # but because know_encoded is a flattened, it's really
    # (N * K, T, D)
    # We need to compute the offsets of the chosen_sentences
    cs_offsets = th.arange(N, device=cs_ids.device) * K + cs_ids
    cs_encoded = know_encoded[cs_offsets]
    # but padding is (N * K, T)
    cs_mask = know_mask[cs_offsets]

    # finally, concatenate it all
    full_enc = th.cat([cs_encoded, context_encoded], dim=1)
    full_mask = th.cat([cs_mask, context_mask], dim=1)

    # also return the knowledge selection mask for the loss
    return full_enc, full_mask, ck_attn
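# Tiny sketch of the offset trick used above, with made-up numbers: the
# knowledge encodings live in a flattened (N * K, T, D) tensor, so the row for
# batch item n's chosen sentence k sits at index n * K + k.
import torch

N, K, T, D = 2, 3, 4, 5
know_encoded = torch.arange(N * K * T * D, dtype=torch.float).reshape(N * K, T, D)
cs_ids = torch.tensor([2, 0])              # chosen sentence index per batch item

cs_offsets = torch.arange(N) * K + cs_ids  # tensor([2, 3])
cs_encoded = know_encoded[cs_offsets]      # (N, T, D): one sentence per batch item
assert torch.equal(cs_encoded[0], know_encoded[2])
assert torch.equal(cs_encoded[1], know_encoded[3])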
def forward(self, input):
    """
    Compute scores from inputs.

    :param input:
        (bsz x seq_len x num_directions * hiddensize) tensor of states, e.g.
        the output states of an RNN

    :returns:
        (bsz x seqlen x num_cands) scores for each candidate
    """
    # next compute scores over dictionary
    if self.numsoftmax > 1:
        bsz = input.size(0)
        seqlen = input.size(1) if input.dim() > 1 else 1

        # first compute different softmax scores based on input vec
        # hsz => numsoftmax * esz
        latent = self.latent(input)
        active = self.dropout(self.activation(latent))
        # esz => num_features
        logit = F.linear(active.view(-1, self.esz), self.weight, self.bias)

        # calculate priors: distribution over which softmax scores to use
        # hsz => numsoftmax
        prior_logit = self.prior(input).view(-1, self.numsoftmax)
        # softmax over numsoftmax's
        prior = self.softmax(prior_logit)

        # now combine priors with logits
        prob = self.softmax(logit).view(bsz * seqlen, self.numsoftmax, -1)
        probs = (prob * prior.unsqueeze(2)).sum(1).view(bsz, seqlen, -1)
        scores = probs.log()
    else:
        # hsz => esz, good time for dropout
        e = self.dropout(self.o2e(input))
        # esz => num_features
        scores = F.linear(e, self.weight, self.bias)

    if self.padding_idx >= 0:
        scores[:, :, self.padding_idx] = neginf(scores.dtype)

    return scores
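# Self-contained sketch of the mixture-of-softmaxes idea with hypothetical
# sizes: project the hidden state into `numsoftmax` separate embedding spaces,
# compute a softmax over the vocabulary for each component, and mix them with a
# learned prior over the components before taking the log.
import torch
import torch.nn.functional as F

bsz, hsz, esz, vocab, numsoftmax = 2, 16, 8, 30, 3
hidden = torch.randn(bsz, hsz)
latent = torch.nn.Linear(hsz, numsoftmax * esz)   # hsz => numsoftmax * esz
prior = torch.nn.Linear(hsz, numsoftmax)          # hsz => numsoftmax
out = torch.nn.Linear(esz, vocab)                 # esz => vocab

logits = out(torch.tanh(latent(hidden)).view(bsz * numsoftmax, esz))
component_probs = F.softmax(logits, dim=-1).view(bsz, numsoftmax, vocab)
mix = F.softmax(prior(hidden), dim=-1).unsqueeze(-1)   # (bsz, numsoftmax, 1)
log_probs = (component_probs * mix).sum(1).log()       # (bsz, vocab)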
def forward(self, src_tokens, know_tokens, ck_mask, res_tokens=None): # encode the context, pretty basic context_encoded, context_mask = self.transformer(src_tokens) # make all the knowledge into a 2D matrix to encode # knowledge is intent for customer and tickets for agent N, K, Tk = know_tokens.size() know_flat = know_tokens.reshape(-1, Tk) know_encoded, know_mask = self.knowledge_transformer(know_flat) if self.agenttype == 'customer': ck_attn = None intent_out = None name_out = None cs_encoded = know_encoded cs_mask = know_mask elif self.agenttype == 'agent': # import ipdb; ipdb.set_trace() # compute our sentence embeddings for context and knowledge context_use = universal_sentence_embedding(context_encoded, context_mask) know_use = universal_sentence_embedding(know_encoded, know_mask) # remash it back into the shape we need know_use = know_use.reshape(N, K, self.embed_dim) # project before calculate attn know_use_proj = self.know_use_project(know_use) ck_attn = th.bmm(know_use_proj, context_use.unsqueeze(-1)).squeeze(-1) ck_attn /= np.sqrt(self.embed_dim) # fill with near -inf ck_attn.masked_fill_(~ck_mask, neginf(context_encoded.dtype)) # Compute context knowledge attn prob ck_prob = nn.functional.softmax(ck_attn, dim=-1) _, cs_ids = ck_attn.max(1) # pick the true chosen sentence. remember that TransformerEncoder outputs # (batch, time, embed) # but because know_encoded is a flattened, it's really # (N * K, T, D) # We need to compute the offsets of the chosen_sentences cs_offsets = th.arange(N, device=cs_ids.device) * K + cs_ids cs_encoded = know_encoded[cs_offsets] # but padding is (N * K, T) cs_mask = know_mask[cs_offsets] # compute reservation embeddings res_encoded, res_mask = self.reservation_transformer(res_tokens) # finally, concatenate it all cs_encoded = th.cat([know_use, cs_encoded, res_encoded], dim=1) cs_mask = th.cat([ck_mask, cs_mask, res_mask], dim=1) # intent prediction intent_out = self.intent_head(context_encoded, context_mask) name_out = self.name_head(context_encoded, context_mask) # finally, concatenate it all full_enc = th.cat([cs_encoded, context_encoded], dim=1) full_mask = th.cat([cs_mask, context_mask], dim=1) # also return the knowledge selection mask for the loss return full_enc, full_mask, ck_attn, intent_out, name_out
def forward( # type: ignore # TODO: remove type ignore with pytorch 1.5: # https://github.com/pytorch/pytorch/pull/31057 self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None, mask: torch.Tensor = None, incr_state: Optional[Dict[str, torch.Tensor]] = None, static_kv: bool = False, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """ Forward pass. :param query: attention query :param key: attention key :param value: attention value :param mask: tensor in which True means that we are allowing attention and False means we are blocking it. Mask is: - [B, key_len] (encoder self-attn and decoder enc/dec attn) - [B, query_len, key_len] (decoder self-attn) - [B, 1, key_len] (decoder self-attn with incr_state caching) :param incr_state: dictionary with values representing the previous states of the key, value, and mask :param static_kv: True if the key and value are held constant during decoding (as in encoder/decoder attention) :return: ( final attended tensor, new incremental state, key/value-multiplied tensor before softmax, ) """ batch_size, query_len, dim = query.size() assert ( dim == self.dim ), 'Dimensions do not match: {} query vs {} configured'.format(dim, self.dim) assert mask is not None, 'Mask is None, please specify a mask' n_heads = self.n_heads dim_per_head = dim // n_heads scale = math.sqrt(dim_per_head) def prepare_head(tensor): # input is [batch_size, seq_len, n_heads * dim_per_head] # output is [batch_size * n_heads, seq_len, dim_per_head] bsz, seq_len, _ = tensor.size() tensor = tensor.view(batch_size, tensor.size(1), n_heads, dim_per_head) tensor = ( tensor.transpose(1, 2) .contiguous() .view(batch_size * n_heads, seq_len, dim_per_head) ) return tensor # q, k, v are the transformed values if key is None and value is None: # self attention key = value = query _, _key_len, dim = query.size() elif value is None: # key and value are the same, but query differs # self attention value = key assert key is not None # let mypy know we sorted this _, _key_len, dim = key.size() q = prepare_head(self.q_lin(query)) k = prepare_head(self.k_lin(key)) v = prepare_head(self.v_lin(value)) # Prepend incremental states. For each of the key, value, and mask, see if # a previous incremental state exists, and if so, reshape it to match the shape # of the new state. Concatenate the previous and new states to match what the # full state would have been if we had not cached. (If we are using static_kv, # these three states are unchanging, so just re-use the cached states.) if incr_state is None: incr_state = {} if 'prev_key' in incr_state: prev_key = incr_state['prev_key'].view( batch_size * n_heads, -1, dim_per_head ) if static_kv: k = prev_key else: k = torch.cat([prev_key, k], dim=1) if 'prev_value' in incr_state: prev_value = incr_state['prev_value'].view( batch_size * n_heads, -1, dim_per_head ) if static_kv: v = prev_value else: v = torch.cat([prev_value, v], dim=1) if 'prev_mask' in incr_state: if static_kv: mask = incr_state['prev_mask'] else: mask = torch.cat([incr_state['prev_mask'], mask], dim=2) # Prepend along the key_len dimension (analogous to # incr_state['prev_key']) # Save new incremental states. We reshape to allow for reordering along batch # dimension. 
new_incr_state = { 'prev_key': k.view(batch_size, n_heads, -1, dim_per_head), 'prev_value': v.view(batch_size, n_heads, -1, dim_per_head), 'prev_mask': mask, } full_key_len = k.size(1) dot_prod = q.div_(scale).bmm(k.transpose(1, 2)) # [B * n_heads, query_len, key_len] attn_mask = ( (mask == 0) .view(batch_size, 1, -1, full_key_len) .repeat(1, n_heads, 1, 1) .expand(batch_size, n_heads, query_len, full_key_len) .view(batch_size * n_heads, query_len, full_key_len) ) assert attn_mask.shape == dot_prod.shape dot_prod.masked_fill_(attn_mask, neginf(dot_prod.dtype)) attn_weights = F.softmax( dot_prod, dim=-1, dtype=torch.float # type: ignore ).type_as(query) attn_weights = self.attn_dropout(attn_weights) # --attention-dropout attentioned = attn_weights.bmm(v) attentioned = ( attentioned.type_as(query) .view(batch_size, n_heads, query_len, dim_per_head) .transpose(1, 2) .contiguous() .view(batch_size, query_len, dim) ) out = self.out_lin(attentioned) return out, new_incr_state, dot_prod
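# Minimal sketch (made-up shapes) of the incremental-state caching pattern used
# in the attention forward above: at each decoding step the new key/value
# projections are concatenated onto the cached ones along the time dimension,
# so attention always sees the full history without re-encoding it. The real
# code additionally reshapes the cache per head; this sketch keeps it flat.
import torch

bsz_heads, dim_per_head = 4, 8
incr_state = {}

for step in range(3):
    k_new = torch.randn(bsz_heads, 1, dim_per_head)   # this step's key
    v_new = torch.randn(bsz_heads, 1, dim_per_head)   # this step's value
    if 'prev_key' in incr_state:
        k = torch.cat([incr_state['prev_key'], k_new], dim=1)
        v = torch.cat([incr_state['prev_value'], v_new], dim=1)
    else:
        k, v = k_new, v_new
    incr_state = {'prev_key': k, 'prev_value': v}

print(incr_state['prev_key'].shape)   # torch.Size([4, 3, 8]) after 3 steps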
def _generate( self, batch: Batch, beam_size: int, max_ts: int, prefix_tokens: tp.Optional[torch.LongTensor] = None, ): """ Generate an output with beam search. Depending on the options, this may perform greedy/topk/nucleus generation. :param Batch batch: Batch structure with input and labels :param int beam_size: Size of each beam during the search :param int max_ts: the maximum length of the decoded sequence :param prefix_tokens: if given, a tensor of tokens that must begin the decoded sequence. :return: tuple (beam_pred_scores, beams) - beam_preds_scores: list of (prediction, score) pairs for each sample in Batch - beams :list of Beam instances defined in Beam class, can be used for any following postprocessing, e.g. dot logging. """ model = self.model if isinstance(model, torch.nn.parallel.DistributedDataParallel): model = self.model.module encoder_states = model.encoder(*self._encoder_input(batch)) if batch.text_vec is not None: dev = batch.text_vec.device else: assert batch.label_vec is not None, "need label_vec for _generate" dev = batch.label_vec.device bsz = ( len(batch.text_lengths) if batch.text_lengths is not None else len( batch.image) # type: ignore ) if batch.text_vec is not None: batchsize = batch.text_vec.size(0) beams = [ self._treesearch_factory(dev).set_context( self._get_context(batch, batch_idx)).set_block_list( self.beam_block_list) for batch_idx in range(batchsize) ] else: beams = [self._treesearch_factory(dev) for _ in range(bsz)] # repeat encoder outputs and decoder inputs decoder_input = self._get_initial_decoder_input(bsz, beam_size, dev) inds = torch.arange(bsz).to(dev).unsqueeze(1).repeat( 1, beam_size).view(-1) encoder_states = model.reorder_encoder_states(encoder_states, inds) incr_state = None for _ts in range(max_ts): if all((b.is_done() for b in beams)): # exit early if possible break score, incr_state = model.decoder(decoder_input, encoder_states, incr_state) # only need the final hidden state to make the word prediction score = score[:, -1:, :] score = model.output(score) # score contains softmax scores for bsz * beam_size samples score = score.view(bsz, beam_size, -1) if self.temperature != 1.0: score.div_(self.temperature) # force to fp32 to avoid overflow issues during search calculations score = F.log_softmax(score, dim=-1, dtype=torch.float32) # type: ignore if prefix_tokens is not None and _ts < prefix_tokens.size(1): # generate prefix_tokens for every timestep that they exist # achieve by setting score of all other tokens to be -inf prefix_toks = prefix_tokens[:, _ts].unsqueeze(-1).repeat( 1, beam_size) prefix_score = score.gather(-1, prefix_toks.unsqueeze(-1)) prefix_mask = prefix_toks.ne(self.NULL_IDX) score[prefix_mask] = neginf(score.dtype) score[prefix_mask] = score[prefix_mask].scatter_( -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_score[prefix_mask], ) for i, b in enumerate(beams): if not b.is_done(): score_in = score[i] score_in += self._nidf_feats.to(dev) b.advance(score_in) incr_state_inds = torch.cat([ beam_size * i + b.get_backtrack_from_current_step() for i, b in enumerate(beams) ]) incr_state = model.reorder_decoder_incremental_state( incr_state, incr_state_inds) selection = torch.cat([ b.get_output_from_current_step() for b in beams ]).unsqueeze(-1) decoder_input = self._get_next_decoder_input( decoder_input, selection, incr_state_inds) # get all finalized candidates for each sample (and validate them) n_best_beam_preds_scores = [b.get_rescored_finished() for b in beams] if hasattr(self, '_rerank_beams'): 
n_best_beam_preds_scores = self._rerank_beams( # type: ignore batch, n_best_beam_preds_scores) # get the top prediction for each beam (i.e. minibatch sample) beam_preds_scores = [ n_best_list[0] for n_best_list in n_best_beam_preds_scores ] return beam_preds_scores, beams
def advance(self, logprobs):
    """
    Advance the beam one step.
    """
    current_length = len(self.all_scores) - 1
    if current_length < self.min_length:
        # penalize all eos probs to make it decode longer
        for hyp_id in range(logprobs.size(0)):
            logprobs[hyp_id][self.eos] = neginf(logprobs.dtype)

    if self.scores is None:
        self.scores = torch.zeros(1).type_as(logprobs).to(logprobs.device)

    # penalize hypotheses ending in EOS on the prior scores (self.scores) level
    # this is related to search which uses prior scores (self.scores) (e.g. beam)
    for hyp_id, token in enumerate(self.outputs[-1]):
        if token == self.eos:
            self.scores[hyp_id] = neginf(self.scores.dtype)

    # beam blocking
    if self.block_ngram > 0:
        logprobs = self._block_ngrams(self.block_ngram, logprobs, None)

    if self.context_block_ngram > 0:
        if self.context is None:
            raise ValueError(
                "Must use TreeSearch.set_context to use context blocking."
            )
        logprobs = self._block_ngrams(
            self.context_block_ngram, logprobs, self.context
        )

    hyp_ids, tok_ids, self.scores = self.select_paths(logprobs, self.scores)
    # use clone() here to ensure that self.all_scores will not be changed
    # later due to any penalties to self.scores
    self.all_scores.append(self.scores.clone())

    self.outputs.append(tok_ids)
    self.bookkeep.append(hyp_ids)
    self.partial_hyps = [
        self.partial_hyps[hyp_ids[i]] + [tok_ids[i].item()]
        for i in range(self.beam_size)
    ]

    # check new hypos for eos label, if we have some, add to finished
    for hypid in range(self.beam_size):
        if self.outputs[-1][hypid] == self.eos:
            if self.scores[hypid] == neginf(self.scores.dtype):
                continue
            # this is finished hypo, adding to finished
            eostail = _HypothesisTail(
                timestep=len(self.outputs) - 1,
                hypid=hypid,
                score=self.all_scores[-1][hypid],
                tokenid=self.eos,
            )
            self.finished.append(eostail)
            self.n_best_counter += 1

    if self.outputs[-1][0] == self.eos:
        self.eos_top = True
        if self.eos_top_ts is None:
            self.eos_top_ts = len(self.outputs) - 1
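# Minimal sketch (toy values) of the min-length trick used at the top of
# advance(): until the hypothesis reaches a minimum length, the EOS column of
# the log-probabilities is set to a near -inf value so the search cannot
# terminate early.
import torch
import torch.nn.functional as F

eos, min_length, current_length = 2, 5, 3
logprobs = F.log_softmax(torch.randn(4, 10), dim=-1)   # (beam_size, vocab)

if current_length < min_length:
    logprobs[:, eos] = -1e20                            # EOS temporarily banned

assert logprobs[:, eos].max().item() <= -1e19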
def forward(self, xes, hidden, attn_params): """ Compute attention over attn_params given input and hidden states. :param xes: input state. will be combined with applied attention. :param hidden: hidden state from model. will be used to select states to attend to in from the attn_params. :param attn_params: tuple of encoder output states and a mask showing which input indices are nonzero. :returns: output, attn_weights output is a new state of same size as input state `xes`. attn_weights are the weights given to each state in the encoder outputs. """ if self.attention == 'none': # do nothing, no attention return xes, None if type(hidden) == tuple: # for lstms use the "hidden" state not the cell state hidden = hidden[0] last_hidden = hidden[-1] # select hidden state from last RNN layer enc_out, attn_mask = attn_params bsz, seqlen, hszXnumdir = enc_out.size() numlayersXnumdir = last_hidden.size(1) if self.attention == 'local': # local attention weights aren't based on encoder states h_merged = torch.cat((xes.squeeze(1), last_hidden), 1) attn_weights = F.softmax(self.attn(h_merged), dim=1) # adjust state sizes to the fixed window size if seqlen > self.max_length: offset = seqlen - self.max_length enc_out = enc_out.narrow(1, offset, self.max_length) seqlen = self.max_length if attn_weights.size(1) > seqlen: attn_weights = attn_weights.narrow(1, 0, seqlen) else: hid = last_hidden.unsqueeze(1) if self.attention == 'concat': # concat hidden state and encoder outputs hid = hid.expand(bsz, seqlen, numlayersXnumdir) h_merged = torch.cat((enc_out, hid), 2) # then do linear combination of them with activation active = F.tanh(self.attn(h_merged)) attn_w_premask = self.attn_v(active).squeeze(2) elif self.attention == 'dot': # dot product between hidden and encoder outputs if numlayersXnumdir != hszXnumdir: # enc_out has two directions, so double hid hid = torch.cat([hid, hid], 2) enc_t = enc_out.transpose(1, 2) attn_w_premask = torch.bmm(hid, enc_t).squeeze(1) elif self.attention == 'general': # before doing dot product, transform hidden state with linear # same as dot if linear is identity hid = self.attn(hid) enc_t = enc_out.transpose(1, 2) attn_w_premask = torch.bmm(hid, enc_t).squeeze(1) # calculate activation scores, apply mask if needed if attn_mask is not None: # remove activation from NULL symbols attn_w_premask.masked_fill_((~attn_mask), neginf(attn_w_premask.dtype)) attn_weights = F.softmax(attn_w_premask, dim=1) # apply the attention weights to the encoder states attn_applied = torch.bmm(attn_weights.unsqueeze(1), enc_out) # concatenate the input and encoder states merged = torch.cat((xes.squeeze(1), attn_applied.squeeze(1)), 1) # combine them with a linear layer and tanh activation output = torch.tanh(self.attn_combine(merged).unsqueeze(1)) return output, attn_weights
def forward(self, src_tokens, know_tokens, ck_mask, cs_ids, use_cs_ids):
    # encode the context, pretty basic
    # N: batch size, K: number of knowledge sentences, T: time, D: embedding size, Tk:
    # src_tokens torch.Size([B, T])
    # cs_ids tensor([0, 0, 0, 0], device='cuda:0')
    # use_cs_ids is True during training
    self.know_tokens = know_tokens
    self.ck_mask = ck_mask
    self.cs_ids = cs_ids
    self.use_cs_ids = use_cs_ids

    context_encoded, context_mask = self.transformer(src_tokens)

    # make all the knowledge into a 2D matrix to encode
    N, K, Tk = know_tokens.size()
    know_encoded, know_mask = self.transformer(know_tokens.reshape(-1, Tk))

    # compute our sentence embeddings for context and knowledge
    context_use = universal_sentence_embedding(context_encoded, context_mask)
    know_use = universal_sentence_embedding(know_encoded, know_mask)

    # remash it back into the shape we need
    know_use = know_use.reshape(
        N, know_tokens.size(1), self.embed_dim
    ) / np.sqrt(self.embed_dim)
    context_use /= np.sqrt(self.embed_dim)

    ck_attn = th.bmm(know_use, context_use.unsqueeze(-1)).squeeze(-1)
    # fill with near -inf
    ck_attn.masked_fill_(~ck_mask, neginf(context_encoded.dtype))

    if self.soft_attention:
        # pick the true chosen sentence. remember that TransformerEncoder outputs
        # (batch, time, embed)
        # but because know_encoded is a flattened, it's really
        # (N * K, T, D)
        # We need to compute the offsets of the chosen_sentences
        cs_encoded = None
        softmax_cs_weight = th.nn.functional.softmax(
            (ck_attn * self.knowledge_lamda), dim=1
        )
        # add
        true_ids_weight = th.zeros(
            softmax_cs_weight.shape,
            device=softmax_cs_weight.device,
            dtype=softmax_cs_weight.dtype,
        )
        for temp in true_ids_weight:
            temp[0] = 1
        weight_abs = th.abs(softmax_cs_weight - true_ids_weight)
        weight_abs *= weight_abs

        _, T, D = know_encoded.size()
        # finally, concatenate it all
        full_enc = th.cat(
            [
                (
                    know_encoded.reshape((N * K, -1))
                    * th.nn.functional.softmax(
                        (ck_attn * self.knowledge_lamda), dim=1
                    ).reshape(-1, 1).expand(N * K, T * D)
                ).reshape((N, K, T, D)).sum(dim=1),
                context_encoded,
            ],
            dim=1,
        )
        full_mask = th.cat(
            [know_mask[th.arange(N, device=cs_ids.device) * K], context_mask],
            dim=1,
        )
        # also return the knowledge selection mask for the loss
        return full_enc, full_mask, ck_attn
    else:
        if not use_cs_ids:
            # if we're not given the true chosen_sentence (test time), pick our
            # best guess
            # this is the branch where cs_ids get used
            _, cs_ids = ck_attn.max(1)
            # _, cs_ids = self.second_max(ck_attn, 1)

        # pick the true chosen sentence. remember that TransformerEncoder outputs
        # (batch, time, embed)
        # but because know_encoded is a flattened, it's really
        # (N * K, T, D)
        # We need to compute the offsets of the chosen_sentences
        cs_offsets = th.arange(N, device=cs_ids.device) * K + cs_ids
        cs_encoded = know_encoded[cs_offsets]
        # but padding is (N * K, T)
        cs_mask = know_mask[cs_offsets]

        # finally, concatenate it all
        full_enc = th.cat([cs_encoded, context_encoded], dim=1)
        full_mask = th.cat([cs_mask, context_mask], dim=1)

        # also return the knowledge selection mask for the loss
        return full_enc, full_mask, ck_attn
def test_neginf(self):
    assert neginf(torch.float32) < -1e15
    assert neginf(torch.float16) > -1e15
    assert neginf(torch.float16) < -1e4
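# A hedged sketch of a `neginf` helper that would satisfy the test above: fp16
# cannot represent values much below -65504, so a constant like -1e20 would
# overflow to -inf and can destabilize masked softmaxes, while fp32 can use a
# far larger magnitude. This is an assumption consistent with the assertions,
# not necessarily the library's exact constants.
import torch

def neginf(dtype: torch.dtype) -> float:
    """Return a large negative value that is safe for the given dtype."""
    if dtype is torch.float16:
        return -65504.0   # roughly the fp16 minimum; > -1e15 and < -1e4
    return -1e20          # comfortably < -1e15 for fp32/fp64

assert neginf(torch.float32) < -1e15
assert -1e15 < neginf(torch.float16) < -1e4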
def forward(self, query, key=None, value=None, mask=None): """ Forward pass. """ # TODO: there are a lot of parameters to document here. # Input is [B, query_len, dim] # Mask is [B, key_len] (selfattn) or [B, key_len, key_len] (enc attn) batch_size, query_len, dim = query.size() assert (dim == self.dim ), 'Dimensions do not match: {} query vs {} configured'.format( dim, self.dim) assert mask is not None, 'Mask is None, please specify a mask' n_heads = self.n_heads dim_per_head = dim // n_heads scale = math.sqrt(dim_per_head) def prepare_head(tensor): # input is [batch_size, seq_len, n_heads * dim_per_head] # output is [batch_size * n_heads, seq_len, dim_per_head] bsz, seq_len, _ = tensor.size() tensor = tensor.view(batch_size, tensor.size(1), n_heads, dim_per_head) tensor = (tensor.transpose(1, 2).contiguous().view( batch_size * n_heads, seq_len, dim_per_head)) return tensor # q, k, v are the transformed values if key is None and value is None: # self attention key = value = query elif value is None: # key and value are the same, but query differs # self attention value = key _, key_len, dim = key.size() q = prepare_head(self.q_lin(query)) k = prepare_head(self.k_lin(key)) v = prepare_head(self.v_lin(value)) dot_prod = q.div_(scale).bmm(k.transpose(1, 2)) # [B * n_heads, query_len, key_len] attn_mask = ((mask == 0).view(batch_size, 1, -1, key_len).repeat( 1, n_heads, 1, 1).expand(batch_size, n_heads, query_len, key_len).view(batch_size * n_heads, query_len, key_len)) assert attn_mask.shape == dot_prod.shape dot_prod.masked_fill_(attn_mask, neginf(dot_prod.dtype)) attn_weights = F.softmax(dot_prod, dim=-1).type_as(query) attn_weights = self.attn_dropout(attn_weights) # --attention-dropout attentioned = attn_weights.bmm(v) attentioned = (attentioned.type_as(query).view( batch_size, n_heads, query_len, dim_per_head).transpose( 1, 2).contiguous().view(batch_size, query_len, dim)) out = self.out_lin(attentioned) return out