def forward(self, inputs):
    words, masks, pos, deprel, head, subj_pos, obj_pos = inputs  # unpack
    src_mask = (words != constant.PAD_ID).unsqueeze(-2)
    word_embs = self.emb(words)
    embs = [word_embs]
    if self.opt['pos_dim'] > 0:
        embs += [self.pos_emb(pos)]
    embs = torch.cat(embs, dim=2)
    embs = self.in_drop(embs)

    if self.opt.get('rnn', False):
        embs = self.input_W_R(embs)
        gcn_inputs = self.rnn_drop(
            self.encode_with_rnn(embs, masks, words.size()[0]))
    else:
        gcn_inputs = embs
    gcn_inputs = self.input_W_G(gcn_inputs)

    layer_list = []
    outputs = gcn_inputs
    adj_list = None
    for i in range(len(self.layers)):
        if i == 0 or i == 3:
            adj_list = self.layers[i](outputs, src_mask)
            if self.opt['data_dir'] != 'dataset/semeval':
                for j in range(len(adj_list)):
                    if i == 3:
                        adj_list[j] = entmax_bisect(
                            adj_list[j], self.alpha_list[self.heads + j])
                    else:
                        adj_list[j] = entmax_bisect(
                            adj_list[j], self.alpha_list[j])
        else:
            outputs = self.layers[i](adj_list, outputs)
            layer_list.append(outputs)

    aggregate_out = torch.cat(layer_list, dim=2)
    dcgcn_output = self.aggregate_W(aggregate_out)
    adj = torch.stack(adj_list, dim=1).sum(dim=1)
    mask = (adj.sum(2) + adj.sum(1)).eq(0).unsqueeze(2)
    return dcgcn_output, mask

def entmax(input_ids, tokenizer, model, prompt, epoch=None, alpha=1.5, max_length=50):
    new_input_ids = deepcopy(input_ids)
    alpha = torch.tensor(alpha, requires_grad=True)
    log = []
    # print(input_ids)
    for _ in range(max_length):
        prediction_scores = model(new_input_ids)[0][0][-1]
        prediction_prob = entmax_bisect(prediction_scores, alpha)
        candidates = torch.nonzero(prediction_prob)
        next_token_id = candidates[torch.randint(candidates.size()[0], (1,))]
        # print(tokenizer.decode(new_input_ids[0], skip_special_tokens=False))
        # print(tokenizer.decode(next_token_id[0], skip_special_tokens=False), ':\t',
        #       prediction_prob[next_token_id].data[0][0])
        new_input_ids = torch.cat((new_input_ids, next_token_id), dim=1)
        log.append((tokenizer.decode(next_token_id[0], skip_special_tokens=False),
                    prediction_prob[next_token_id].item()))
    # pprint(log)
    output_sent = tokenizer.decode(new_input_ids[0], skip_special_tokens=False)
    # if epoch is not None:
    #     prompt = f'epoch{epoch}_{prompt}'
    draw_prob_graph(log, text=output_sent, filename=prompt, title=f'GPT entmax epoch{epoch}')
    return output_sent

def attention(self, query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = entmax_bisect(scores, alpha=self.alpha, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

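# A standalone sketch (toy tensor, assumed alpha value) contrasting the entmax
# attention above with softmax: entmax_bisect can assign exactly zero weight to
# low-scoring positions, whereas softmax keeps every weight strictly positive.
import torch
import torch.nn.functional as F
from entmax import entmax_bisect

scores = torch.tensor([[2.0, 1.0, 0.1, -1.5]])
print(F.softmax(scores, dim=-1))                 # dense: all entries > 0
print(entmax_bisect(scores, alpha=1.5, dim=-1))  # sparse: trailing entries may be exactly 0
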
def __init__(
    self,
    plate_name: str,
    sampling_method: Optional[storch.sampling.SamplingMethod] = None,
    alpha: float = 1.5,
    adaptive=False,
    n_samples: int = 1,
    straight_through=False,
    initial_temperature=1.0,
    min_temperature=1.0e-4,
    annealing_rate=0.0,
):
    if not sampling_method:
        sampling_method = storch.sampling.MonteCarlo(plate_name, n_samples)
    super().__init__(
        plate_name,
        sampling_method.set_mc_sample(self.sample_gumbel_entmax),
    )
    self.adaptive = adaptive
    self.straight_through = straight_through
    self.register_buffer("temperature", torch.tensor(initial_temperature))
    self.register_buffer("annealing_rate", torch.tensor(annealing_rate))
    self.register_buffer("min_temperature", torch.tensor(min_temperature))
    self.alpha = alpha
    if adaptive:
        self.alpha = torch.nn.Parameter(
            torch.tensor(self.alpha, requires_grad=True)
        )
    if not adaptive and alpha == 1.5:
        self.entmax = entmax.entmax15
    elif not adaptive and alpha == 2.0:
        self.entmax = entmax.sparsemax
    else:
        if adaptive:
            self.entmax = lambda x: entmax.entmax_bisect(
                x, torch.nn.functional.softplus(self.alpha - 1) + 1
            )
        else:
            self.entmax = lambda x: entmax.entmax_bisect(x, self.alpha)

def forward(self, x):
    b, t, e = x.size()
    h = self.heads
    assert e == self.emb, f'Input embedding dim ({e}) should match layer embedding dim ({self.emb})'

    s = e // h
    x = x.view(b, t, h, s)

    keys = self.tokeys(x)
    queries = self.toqueries(x)
    values = self.tovalues(x)

    assert keys.size() == (b, t, h, s)
    assert queries.size() == (b, t, h, s)
    assert values.size() == (b, t, h, s)

    # Compute scaled dot-product self-attention
    # - fold heads into the batch dimension
    keys = keys.transpose(1, 2).contiguous().view(b * h, t, s)
    queries = queries.transpose(1, 2).contiguous().view(b * h, t, s)
    values = values.transpose(1, 2).contiguous().view(b * h, t, s)

    queries = queries / (e ** (1 / 4))
    keys = keys / (e ** (1 / 4))
    # - Instead of dividing the dot products by sqrt(e), we scale the queries and keys.
    #   This should be more memory efficient

    # - get dot product of queries and keys, and scale
    dot = torch.bmm(queries, keys.transpose(1, 2))

    assert dot.size() == (b * h, t, t)

    if self.mask:  # mask out the upper half of the dot matrix, excluding the diagonal
        mask_(dot, maskval=float('-inf'), mask_diagonal=False)

    # dot = F.softmax(dot, dim=-1)
    # dot = sparsemax(dot, dim=-1)
    dot = entmax_bisect(dot, alpha=self.alpha, dim=-1)
    # - dot now has row-wise self-attention probabilities

    # apply the self attention to the values
    out = torch.bmm(dot, values).view(b, h, t, s)

    # swap h, t back, unify heads
    out = out.transpose(1, 2).contiguous().view(b, t, s * h)

    return self.unifyheads(out)

def forward(self, scores: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
    """Map a score vector to a probability distribution akin to softmax (alpha=1)
    and sparsemax (alpha=2).

    Args:
        scores (torch.Tensor): (Batch x Sequence Length) Attention scores
            (also referred to as weights)
        mask (torch.BoolTensor): (Batch x Sequence Length) Specifies which
            indices are just padding

    Returns:
        torch.Tensor: Distribution resulting from entmax with the specified alpha
    """
    # Entmax is only defined for alpha > 1
    self.alpha.data = torch.clamp(self.alpha.data, min=1.001)
    masked_scores = replace_masked_values(scores, mask, -float("inf"))
    return entmax_bisect(masked_scores, self.alpha, dim=-1)

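# A minimal sketch of the same pattern with toy tensors (the masking helper and
# values below are assumptions, not the original module): keep a learnable alpha
# strictly above 1, push padded positions to -inf, and let entmax_bisect
# renormalise over the remaining positions.
import torch
from entmax import entmax_bisect

alpha = torch.nn.Parameter(torch.tensor(1.5))
scores = torch.tensor([[1.0, 0.5, -0.3, 0.0]])
mask = torch.tensor([[True, True, True, False]])  # last position is padding

with torch.no_grad():
    alpha.clamp_(min=1.001)  # entmax is only defined for alpha > 1
masked_scores = scores.masked_fill(~mask, -float("inf"))
probs = entmax_bisect(masked_scores, alpha, dim=-1)
print(probs)  # the padded position receives exactly zero probability
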
def alpha_entmax_loss(model, batch, args):
    longer_sample = batch[0].to(args.gpu)
    inp = longer_sample[:, :args.train_batch_size]
    model_output = model(input_ids=inp)
    target = longer_sample[:, 1:args.train_batch_size + 1]
    logits = model_output[0]
    alpha = torch.tensor([args.alpha], requires_grad=True,
                         device=torch.device(args.gpu))
    probs = entmax_bisect(logits, alpha)

    loss = ((probs - F.one_hot(target, num_classes=probs.size(-1))) * logits).sum(-1)
    loss += alpha_entropy(probs, args.alpha)
    loss = loss.sum()

    true_token_logits = -F.nll_loss(logits[0], target[0], reduction='none')
    ntokens = inp.numel()

    arange = np.arange(probs.size(1))
    next_token_probs = probs[:, arange, target.squeeze().tolist()]
    voc_sizes = probs.size(-1)
    smoothed_nll = -torch.mean(
        torch.log((next_token_probs + args.laplas_eps) /
                  (1 + args.laplas_eps * voc_sizes)))

    logging_output = TrainingMetrics.ranking_metrics(logits[0].float(),
                                                     true_token_logits, None,
                                                     ntokens, target[0])
    logging_output['loss'] = loss.item()
    logging_output['smoothed_nll_loss'] = smoothed_nll.item()
    logging_output['normalizer'] = ntokens
    logging_output['sample_size'] = ntokens
    logging_output['ntokens'] = ntokens
    logging_output['js_div'] = jensen_shannon_divergence(probs, target).mean().item()
    print(logging_output['js_div'])

    loss = loss / ntokens
    return loss, logging_output

def forward(self, hidden, orig_prob, attn, src_map):
    """
    Compute a distribution over the target dictionary
    extended by the dynamic dictionary implied by copying
    source words.

    Args:
        hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
        attn (FloatTensor): attn for each ``(batch x tlen, input_size)``
        src_map (FloatTensor):
            A sparse indicator matrix mapping each source word to
            its index in the "extended" vocabulary.
            ``(src_len, batch, extra_words)``
    """

    # CHECKS
    # batch_by_tlen, _ = hidden.size()
    # batch_by_tlen_, slen = attn.size()
    onehot_src_map = F.one_hot(src_map.long(), torch.max(src_map).long() + 1)
    batch, slen, cvocab = onehot_src_map.size()

    if self.use_entmax:
        prob = entmax_bisect(orig_prob, 1.2)
    else:
        prob = torch.softmax(orig_prob, 1)

    # Probability of copying p(z=1) batch.
    p_copy = torch.sigmoid(self.linear_copy(hidden))
    # Probability of not copying: p_{word}(w) * (1 - p(z))
    out_prob = torch.mul(prob, 1 - p_copy)
    mul_attn = torch.mul(attn, p_copy)
    copy_prob = torch.bmm(
        mul_attn.view(batch, -1, slen),  # batch size x tgt len x src len
        onehot_src_map.float())          # batch size x src len x cvocab
    copy_prob = copy_prob.contiguous().view(-1, cvocab)

    return out_prob, copy_prob

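# A small illustrative sketch (toy shapes and numbers, stand-ins for the module's
# linear_copy and decoder outputs) of the copy-generator mixture above: the
# generation distribution is scaled by (1 - p_copy), the attention over source
# tokens by p_copy, and attention mass is scattered into the extended vocabulary
# through the one-hot src_map, so the two parts together still sum to one.
import torch
import torch.nn.functional as F
from entmax import entmax_bisect

batch, tgt_len, slen, vocab = 1, 1, 3, 5
orig_prob = torch.randn(batch * tgt_len, vocab)              # decoder output scores
attn = torch.softmax(torch.randn(batch * tgt_len, slen), -1)  # attention over source tokens
src_map = torch.tensor([[0, 2, 4]])                           # source positions -> extended-vocab ids
onehot_src_map = F.one_hot(src_map, num_classes=5).float()    # (batch, slen, cvocab)

prob = entmax_bisect(orig_prob, 1.2)                          # sparse generation distribution
p_copy = torch.sigmoid(torch.randn(batch * tgt_len, 1))       # stand-in for linear_copy(hidden)
out_prob = prob * (1 - p_copy)
copy_prob = torch.bmm((attn * p_copy).view(batch, -1, slen), onehot_src_map).view(-1, 5)
print(out_prob.sum(-1) + copy_prob.sum(-1))                   # ~1.0: a proper mixture
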
def forward(self):
    self.Y = entmax_bisect(self.X, self.alpha, dim=-1, n_iter=self.n_iter)

def log_entmax(*args, **kwargs):
    return torch.log(entmax_bisect(*args, **kwargs))

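# Caveat worth noting for the wrapper above: entmax distributions contain exact
# zeros, so a plain log yields -inf at those positions. Other snippets in this
# collection add a small epsilon before the log; a hedged sketch of both behaviours:
import torch
from entmax import entmax_bisect

scores = torch.tensor([[3.0, 1.0, -2.0, -5.0]])
probs = entmax_bisect(scores, alpha=1.5, dim=-1)
print(torch.log(probs))          # -inf where entmax assigns zero mass
print(torch.log(probs + 1e-10))  # finite everywhere, as in the decoding code below
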
def eval_singletoken(model, args, dataset_paths, config, top_k=1, top_p=0.0,
                     t=1.0, train_iter=None, batch_size=None):
    alpha_entmax = args.alpha_entmax
    batch_size = batch_size if batch_size is not None else args.batch_size_singletoken
    datasets = get_datasets(dataset_paths, max_len=batch_size)

    eval_sampler = SequentialSampler(datasets[args.eval_split])
    eval_dataloader = DataLoader(datasets[args.eval_split],
                                 sampler=eval_sampler,
                                 batch_size=1)

    model.eval()

    logging_outputs = []
    predicted_tokens = []
    target_tokens = []
    with torch.no_grad():
        for i, batch in tqdm(enumerate(eval_dataloader),
                             desc="Evaluating",
                             total=len(eval_dataloader)):
            longer_sample = batch[0].to(args.gpu)
            inp = longer_sample[:, :args.batch_size_singletoken]
            model_output = model(input_ids=inp)
            target = longer_sample[:, 1:]
            logits = model_output[0]
            log_softmax_probs = F.log_softmax(logits, dim=-1)
            nll = F.nll_loss(log_softmax_probs[0], target[0], reduction='sum')
            true_token_logits = -F.nll_loss(logits[0], target[0], reduction='none')

            if alpha_entmax is False:
                filtered_logits = top_k_top_p_filtering(logits.squeeze(0),
                                                        top_k=args.top_k,
                                                        top_p=args.top_p).unsqueeze(0)
                prev = F.softmax(filtered_logits.view(filtered_logits.shape[1:]),
                                 dim=-1).multinomial(num_samples=1).unsqueeze(0).squeeze(-1)
                probs = F.softmax(filtered_logits, dim=-1)
            else:
                probs = entmax_bisect(
                    logits,
                    torch.tensor([args.alpha],
                                 requires_grad=True,
                                 device=torch.device(args.gpu)).float())

            arange = np.arange(logits.size(1))
            next_token_probs = probs[:, arange, target.squeeze().tolist()]
            voc_sizes = probs.size(-1)
            smoothed_nll = -torch.mean(
                torch.log((next_token_probs + args.laplas_eps) /
                          (1 + args.laplas_eps * voc_sizes)))

            pred = probs.view(-1, probs.size(-1)).multinomial(num_samples=1).view(probs.shape[:-1])
            predicted_tokens.extend(pred.view(-1).tolist())
            ntokens = inp.numel()

            rep_logits = torch.zeros_like(logits)
            rep_logits[:, arange, pred.squeeze().tolist()] = 1
            logging_output = TrainingMetrics.ranking_metrics(rep_logits[0].float(),
                                                             true_token_logits,
                                                             None, ntokens,
                                                             target[0])
            logging_output['loss'] = nll.item()
            logging_output['smoothed_nll_loss'] = smoothed_nll.item()
            logging_output['normalizer'] = ntokens
            logging_output['sample_size'] = ntokens
            logging_output['ntokens'] = ntokens
            logging_output['js_div'] = jensen_shannon_divergence(probs, target).mean().item()

            if args.token_loss == 'alpha_entmax':
                loss = ((probs - F.one_hot(target, num_classes=probs.size(-1))) * logits).sum(-1)
                loss += alpha_entropy(probs, args.alpha)
                logging_output['alpha_entmax_loss'] = loss.mean().item()

            logging_outputs.append(logging_output)

            # for human uniq
            target_tokens.extend(target.view(-1).tolist())

    logging_average = CrossEntropyCriterionWCustomMetrics.aggregate_logging_outputs(logging_outputs)
    logging_average['e_ppl'] = np.exp(np.mean([x['smoothed_nll_loss'] for x in logging_outputs]))
    # aggregate_logging_outputs does division by log(2) of loss
    logging_average['ppl'] = 2 ** logging_average['loss']
    logging_average['human_uniq'] = len(set(target_tokens))
    logging_average['uniq'] = len(set(predicted_tokens))
    logging_average['wrep'] = np.mean(
        [v for k, v in logging_average.items() if k.startswith('wrong_repeat')])
    logging_average['rep'] = np.mean(
        [v for k, v in logging_average.items() if k.startswith('repeat')])
    logging_average['js_div'] = np.mean([x['js_div'] for x in logging_outputs])
    if args.token_loss == 'alpha_entmax':
        logging_average['alpha_entmax_loss'] = np.mean(
            [x['alpha_entmax_loss'] for x in logging_outputs])

    save_singletoken_sampling_metrics(logging_average,
                                      config.to_dict(),
                                      args,
                                      top_k=top_k,
                                      top_p=top_p,
                                      train_iter=train_iter)
    return logging_average

def sample_sequence(model, prefix_batch, prefix_length, continuation_length,
                    num_samples=1, top_k=0, top_p=0.0, temperature=1.0,
                    alpha_entmax=False, output_prefix_hidden=False,
                    repetition_penalty=1.0, **kwargs):
    continuation_logits = []
    context = prefix_batch
    context = torch.cat([context] * num_samples, 0)
    assert context.size(1) == prefix_length

    prev = context
    output = context
    past = None

    log_probs = torch.zeros((num_samples * prefix_batch.size(0), continuation_length))
    policy_pis = []

    for i in range(continuation_length):
        model_output = model(input_ids=prev, past=past)
        logits, past = model_output[:2]
        if i == 0 and output_prefix_hidden:
            prefix_hidden = model_output[2]
        logits = logits[:, -1, :]
        logits = logits / temperature

        if repetition_penalty != 1.0:
            for ex_id, pert_logits in enumerate(logits):
                for token_idx in set(output[ex_id].tolist()):
                    if pert_logits[token_idx] < 0:
                        pert_logits[token_idx] *= repetition_penalty
                    else:
                        pert_logits[token_idx] /= repetition_penalty

        if alpha_entmax is False:
            if top_k == 1 and top_p == 0:
                filtered_logits = logits
                prev = logits.float().argmax(dim=1, keepdim=True)
            else:
                filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                prev = F.softmax(filtered_logits, dim=-1).multinomial(num_samples=1)
            # log_prob = F.log_softmax(filtered_logits, dim=-1)
            log_prob = F.log_softmax(logits, dim=-1)
        else:
            alpha = kwargs.get('alpha', 1.0)
            prob = entmax_bisect(
                logits,
                torch.tensor([alpha], requires_grad=True, device=logits.device).float())
            log_prob = torch.log(prob)
            prev = prob.multinomial(num_samples=1)
            filtered_logits = logits

        continuation_logits.append(logits)
        output = torch.cat((output, prev), dim=1)

        arange = np.arange(filtered_logits.size(0))
        next_token_logit = filtered_logits[arange, prev.squeeze().tolist()].squeeze()
        next_token_log_prob = log_prob[arange, prev.squeeze().tolist()].squeeze()
        log_probs[:, i] = next_token_log_prob
        policy_pis.append(log_prob.squeeze())

    policy_pis = torch.stack(policy_pis, 1)
    continuation_logits = torch.stack(continuation_logits, 1)

    if output_prefix_hidden:
        result = (output, log_probs, continuation_logits, policy_pis, prefix_hidden)
    else:
        result = (output, log_probs, continuation_logits, policy_pis)
    return result

def _generate_beam(self, src_enc, src_mask, beam_size, length_penalty=0.0,
                   early_stopping=False, min_len=0, max_len=200,
                   trigram_blocking=False, return_all=False, src_map=None,
                   src_tgt_vocab_map=None):
    """
    Decode a sentence given initial start.
    `x`:
        - LongTensor(bs, slen)
            <EOS> W1 W2 W3 <EOS> <PAD>
            <EOS> W1 W2 W3 W4 <EOS>
    `lengths`:
        - LongTensor(bs) [5, 6]
    `positions`:
        - False, for regular "arange" positions (LM)
        - True, to reset positions from the new generation (MT)
    `langs`:
        - must be None if the model only supports one language
        - lang_id if only one language is involved (LM)
        - (lang_id1, lang_id2) if two languages are involved (MT)
    """

    # check inputs
    assert src_enc.size(0) == src_mask.size(0)
    assert beam_size >= 1

    # batch size / number of words
    bs = len(src_mask)
    n_words = self.n_words if not self.use_copy else self.n_words + src_tgt_vocab_map.shape[1]

    # expand to beam size the source latent representations / source lengths
    src_enc = src_enc.unsqueeze(1).expand(
        (bs, beam_size) + src_enc.shape[1:]).contiguous().view(
            (bs * beam_size, ) + src_enc.shape[1:])
    src_mask = src_mask.unsqueeze(1).expand(
        (bs, beam_size) + src_mask.shape[1:]).contiguous().view(
            (bs * beam_size, ) + src_mask.shape[1:])
    if src_tgt_vocab_map is not None:
        src_tgt_vocab_map = src_tgt_vocab_map.unsqueeze(1).expand(
            (bs, beam_size) + src_tgt_vocab_map.shape[1:]).contiguous().view(
                (bs * beam_size, ) + src_tgt_vocab_map.shape[1:])
    if src_map is not None:
        src_map = src_map.unsqueeze(1).expand(
            (bs, beam_size) + src_map.shape[1:]).contiguous().view(
                (bs * beam_size, ) + src_map.shape[1:])
    # src_len = src_len.unsqueeze(1).expand(bs, beam_size).contiguous().view(-1)

    # generated sentences (batch with beam current hypotheses)
    generated = src_enc.new(bs * beam_size, max_len)  # upcoming output
    generated.fill_(self.pad_index)  # fill upcoming output with <PAD>
    generated[:, 0].fill_(self.bos_index)  # we use <EOS> for <BOS> everywhere

    # generated hypotheses
    generated_hyps = [
        BeamHypotheses(beam_size, max_len, length_penalty, early_stopping)
        for _ in range(bs)
    ]

    trigram_set = [set() for _ in range(bs * beam_size)]

    # scores for each sentence in the beam
    beam_scores = src_enc.new(bs, beam_size).fill_(0)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view(-1)

    # current position
    cur_len = 1

    # cache compute states
    cache = {'slen': 0}

    # done sentences
    done = [False for _ in range(bs)]

    while cur_len < max_len:

        # compute word scores
        tensor, _ = self.fwd(
            x=generated[:, :cur_len] if not self.use_copy else
            generated[:, :cur_len].masked_fill(
                generated[:, :cur_len].gt(self.n_words - 1), 0),
            src_enc=src_enc,
            src_mask=src_mask,
            cache=cache,
            src_map=src_map)
        if self.use_copy:
            tensor = torch.cat(tensor, 1)
            scores, _ = model_utils.collapse_copy_scores(
                scores=tensor,
                src_tgt_vocab_map=src_tgt_vocab_map,
                vocab_size=self.n_words)
            scores[:, self.n_words] = 0
        else:
            assert tensor.size() == (bs * beam_size, 1, self.n_words)
            scores = tensor[:, -1, :]  # (bs * beam_size, dim)

        scores[:, 0] = -float('Inf') if not self.use_copy else 0
        scores[:, self.pad_index] = -float('Inf') if not self.use_copy else 0
        scores[:, self.bos_index] = -float('Inf') if not self.use_copy else 0
        if cur_len < min_len:
            scores[:, self.eos_index] = -float('Inf') if not self.use_copy else 0

        if self.use_copy:
            scores = (scores + 1e-10).log()
        elif self.use_entmax:
            scores = torch.log(entmax_bisect(scores, 1.2) + 1e-10)
        else:
            scores = F.log_softmax(scores, dim=-1)  # (bs * beam_size, n_words)
        assert scores.size() == (bs * beam_size, n_words), (scores.shape,
                                                            (bs * beam_size, n_words))

        # select next words with scores
        _scores = scores + beam_scores[:, None].expand_as(scores)  # (bs * beam_size, n_words)
        _scores = _scores.view(bs, beam_size * n_words)  # (bs, beam_size * n_words)

        next_scores, next_words = torch.sort(_scores, dim=1, descending=True)
        assert next_scores.size() == next_words.size() == (bs, n_words * beam_size)

        # next batch beam content
        # list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch)
        next_batch_beam = []

        # for each sentence
        for sent_id in range(bs):

            # if we are done with this sentence
            done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(
                next_scores[sent_id].max().item())
            if done[sent_id]:
                next_batch_beam.extend([(0, self.pad_index, 0)] * beam_size)  # pad the batch
                continue

            # next sentence beam content
            next_sent_beam = []
            n_add = 0
            # next words for this sentence
            for idx, value in zip(next_words[sent_id], next_scores[sent_id]):

                # get beam and word IDs
                beam_id = idx // n_words
                word_id = idx % n_words

                if trigram_blocking and cur_len > 2:
                    trigram = tuple(
                        generated[sent_id * beam_size + beam_id, cur_len - 2:cur_len].tolist()
                        + [word_id.item()])
                    if trigram in trigram_set[sent_id * beam_size + beam_id]:
                        continue

                # end of sentence, or next word
                if word_id == self.eos_index or cur_len + 1 == max_len:
                    n_add += 1
                    generated_hyps[sent_id].add(
                        generated[sent_id * beam_size + beam_id, :cur_len].clone(),
                        value.item())
                else:
                    next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
                    if trigram_blocking and cur_len > 2:
                        trigram_set[sent_id * beam_size + beam_id].add(trigram)

                # the beam for next step is full
                if len(next_sent_beam) == beam_size or (cur_len + 1 == max_len
                                                        and n_add == beam_size):
                    break

            # update next beam content
            assert len(next_sent_beam) == 0 if cur_len + 1 == max_len else beam_size
            if len(next_sent_beam) == 0:
                next_sent_beam = [(0, self.pad_index, 0)] * beam_size  # pad the batch
            next_batch_beam.extend(next_sent_beam)
            assert len(next_batch_beam) == beam_size * (sent_id + 1)

        # sanity check / prepare next batch
        assert len(next_batch_beam) == bs * beam_size
        beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
        beam_words = generated.new([x[1] for x in next_batch_beam])
        beam_idx = generated.new([x[2] for x in next_batch_beam]).long()

        # re-order batch and internal states
        trigram_set = [deepcopy(trigram_set[x[2]]) for x in next_batch_beam]
        generated = generated[beam_idx, :]
        generated[:, cur_len] = beam_words
        for k in cache.keys():
            if k != 'slen':
                cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])

        # update current length
        cur_len = cur_len + 1

        # stop when we are done with each sentence
        if all(done):
            break

    if return_all:
        return generated_hyps

    # select the best hypotheses
    tgt_len = src_enc.new(bs).long()
    best = []
    best_scores = []

    for i, hypotheses in enumerate(generated_hyps):
        best_score, best_hyp = max(hypotheses.hyp, key=lambda x: x[0])
        tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
        best.append(best_hyp)
        best_scores.append(best_score)

    # generate target batch
    decoded = src_enc.new(tgt_len.max().item(), bs).fill_(self.pad_index)
    for i, hypo in enumerate(best):
        decoded[:tgt_len[i] - 1, i] = hypo
        decoded[tgt_len[i] - 1, i] = self.eos_index

    # sanity check
    assert (decoded == self.eos_index).sum() == bs

    return decoded.transpose(0, 1).cpu().numpy(), best_scores, tgt_len.cpu().numpy()

def _generate(self, src_enc, src_mask, max_len=200, min_len=0, top_p=None,
              src_map=None, src_tgt_vocab_map=None):
    """
    Decode a sentence given initial start.
    `x`:
        - LongTensor(bs, slen)
            <EOS> W1 W2 W3 <EOS> <PAD>
            <EOS> W1 W2 W3 W4 <EOS>
    `lengths`:
        - LongTensor(bs) [5, 6]
    `positions`:
        - False, for regular "arange" positions (LM)
        - True, to reset positions from the new generation (MT)
    `langs`:
        - must be None if the model only supports one language
        - lang_id if only one language is involved (LM)
        - (lang_id1, lang_id2) if two languages are involved (MT)
    """

    # input batch
    bs = len(src_mask)
    assert src_enc.size(0) == bs

    # generated sentences
    generated = src_mask.new(bs, max_len)  # upcoming output
    generated.fill_(self.pad_index)  # fill upcoming output with <PAD>
    generated[:, 0].fill_(self.bos_index)  # we use <EOS> for <BOS> everywhere

    # current position / max lengths / length of generated sentences / unfinished sentences
    cur_len = 1
    # gen_len = torch.ones(bs).to(src_mask.device).long()
    unfinished_sents = torch.ones(bs).to(src_mask.device).long()  # src_len.clone().fill_(1)
    all_scores = torch.zeros(bs).to(src_mask.device)

    # cache compute states
    cache = {'slen': 0}

    while cur_len < max_len:

        # compute word scores
        tensor, _ = self.fwd(
            x=generated[:, :cur_len] if not self.use_copy else
            generated[:, :cur_len].masked_fill(
                generated[:, :cur_len].gt(self.n_words - 1), 0),
            src_enc=src_enc,
            src_mask=src_mask,
            cache=cache,
            src_map=src_map)
        if self.use_copy:
            tensor = torch.cat(tensor, 1)
            scores, _ = model_utils.collapse_copy_scores(
                scores=tensor,
                src_tgt_vocab_map=src_tgt_vocab_map,
                vocab_size=self.n_words)
            scores[:, self.n_words] = 0.0
        else:
            assert tensor.size() == (bs, 1, self.n_words), (cur_len, max_len, src_enc.size(),
                                                            tensor.size(), (1, bs, self.n_words))
            scores = tensor[:, -1, :]  # (bs, dim)
            # scores = self.pred_layer.get_scores(tensor)  # (bs, n_words)

        scores[:, 0] = -float('Inf') if not self.use_copy else 0
        scores[:, self.pad_index] = -float('Inf') if not self.use_copy else 0
        scores[:, self.bos_index] = -float('Inf') if not self.use_copy else 0
        if cur_len < min_len:
            scores[:, self.eos_index] = -float('Inf') if not self.use_copy else 0

        # select next words: sample or greedy
        if top_p:
            if self.use_copy:
                next_words = torch.multinomial(
                    model_utils.top_k_top_p_filtering(scores,
                                                      top_k=0.0,
                                                      top_p=top_p if top_p else 0.0,
                                                      filter_value=0.0,
                                                      need_softmax=False), 1).squeeze(1)
                next_scores = (scores + 1e-10).log().gather(1, next_words.unsqueeze(1)).squeeze(1)
            else:
                next_words = torch.multinomial(
                    F.softmax(model_utils.top_k_top_p_filtering(scores,
                                                                top_k=0.0,
                                                                top_p=top_p if top_p else 0.0),
                              dim=1), 1).squeeze(1)
                next_scores = scores.log_softmax(1).gather(1, next_words.unsqueeze(1)).squeeze(1)
        else:
            if self.use_copy:
                next_scores, next_words = (scores + 1e-10).log().max(1)
            elif self.use_entmax:
                next_scores, next_words = (entmax_bisect(scores, 1.2) + 1e-10).log().max(1)
            else:
                next_scores, next_words = scores.log_softmax(1).max(1)
        assert next_words.size() == (bs, )

        # update generations / lengths / finished sentences / current length
        generated[:, cur_len] = next_words * unfinished_sents + self.pad_index * (1 - unfinished_sents)
        all_scores = all_scores + next_scores * unfinished_sents.float()
        # gen_len.add_(unfinished_sents)
        unfinished_sents.mul_(next_words.ne(self.eos_index).long())
        cur_len = cur_len + 1

        # stop when there is a </s> in each sentence, or if we exceed the maximum length
        if unfinished_sents.max() == 0:
            break

    # add <EOS> to unfinished sentences
    if cur_len == max_len:
        generated[:, -1].masked_fill_(unfinished_sents.bool(), self.eos_index)

    # sanity check
    assert (generated == self.eos_index).sum() == bs

    return generated[:, 1:cur_len].cpu().numpy(), all_scores.cpu().numpy()  # , gen_len

def forward(
    self,
    query,
    key,
    value,
    key_padding_mask=None,
    incremental_state=None,
    need_weights=True,
    static_kv=False,
    attn_mask=None,
    before_softmax=False,
    need_head_weights=False,
):
    """Input shape: Time x Batch x Channel

    Args:
        key_padding_mask (ByteTensor, optional): mask to exclude
            keys that are pads, of shape `(batch, src_len)`, where
            padding elements are indicated by 1s.
        need_weights (bool, optional): return the attention weights,
            averaged over heads (default: False).
        attn_mask (ByteTensor, optional): typically used to
            implement causal attention, where the mask prevents the
            attention from looking forward in time (default: None).
        before_softmax (bool, optional): return the raw attention
            weights and values before the attention softmax.
        need_head_weights (bool, optional): return the attention
            weights for each head. Implies *need_weights*. Default:
            return the average attention weights over all heads.
    """
    if need_head_weights:
        need_weights = True

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == self.embed_dim
    assert list(query.size()) == [tgt_len, bsz, embed_dim]

    if self.enable_torch_version and not self.onnx_trace and incremental_state is None and not static_kv:
        return F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            torch.empty([0]),
            torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
            self.bias_k, self.bias_v, self.add_zero_attn, self.dropout,
            self.out_proj.weight, self.out_proj.bias, self.training,
            key_padding_mask, need_weights, attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight)

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if 'prev_key' in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert self.encoder_decoder_attention and not self.self_attention
                key = value = None
    else:
        saved_state = None

    if self.self_attention:
        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)
    elif self.encoder_decoder_attention:
        # encoder-decoder attention
        q = self.q_proj(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            k = self.k_proj(key)
            v = self.v_proj(key)
    else:
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
    q *= self.scaling

    if self.bias_k is not None:
        assert self.bias_v is not None
        k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat([
                key_padding_mask,
                key_padding_mask.new_zeros(key_padding_mask.size(0), 1)
            ], dim=1)

    q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if 'prev_key' in saved_state:
            prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                k = torch.cat((prev_key, k), dim=1)
        if 'prev_value' in saved_state:
            prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                v = prev_value
            else:
                v = torch.cat((prev_value, v), dim=1)
        key_padding_mask = self._append_prev_key_padding_mask(
            key_padding_mask=key_padding_mask,
            prev_key_padding_mask=saved_state.get('prev_key_padding_mask', None),
            batch_size=bsz,
            src_len=k.size(1),
            static_kv=static_kv,
        )

        saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state['prev_key_padding_mask'] = key_padding_mask

        self._set_input_buffer(incremental_state, saved_state)

    src_len = k.size(1)

    # This is part of a workaround to get around fork/join parallelism
    # not supporting Optional types.
    if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
        key_padding_mask = None

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if self.add_zero_attn:
        src_len += 1
        k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
        v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat([
                key_padding_mask,
                torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)
            ], dim=1)

    if not bmm_fp16_support:
        q = q.float()
        k = k.float()
        v = v.float()
    attn_weights = torch.bmm(q, k.transpose(1, 2))
    if not bmm_fp16_support:
        attn_weights = attn_weights.type_as(query)
    attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

    assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        if self.onnx_trace:
            attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
        attn_weights += attn_mask

    if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if before_softmax:
        return attn_weights, v

    # 1
    if not self.cur_san_active:
        self.div = 0
    if self.div > 0:
        top_k = int(torch.ceil(torch.Tensor([src_len / self.div])))
        if top_k < self.lb:
            top_k = self.lb
        if top_k > src_len:
            top_k = src_len
    else:
        top_k = -self.div
        if top_k > src_len:
            top_k = src_len

    # 2
    # print('attn_weights ', attn_weights.size())
    if self.entmax:
        from entmax import sparsemax, entmax15, entmax_bisect
        if self.entmax == 1:
            attn_weights_float = sparsemax(attn_weights.float(), dim=-1)
        elif self.entmax == 2:
            attn_weights_float = entmax15(attn_weights.float(), dim=-1)
        elif self.entmax == 3:
            attn_weights_float = entmax_bisect(attn_weights.float(), dim=-1)
    else:
        if self.div:
            # keep only the top-k scores per query, mask the rest before softmax
            vk, _ = torch.topk(attn_weights, top_k)
            # print(value)
            tk = vk[:, :, -1].unsqueeze(2).expand_as(attn_weights)
            mask_k = torch.lt(attn_weights, tk)
            attn_weights = attn_weights.masked_fill(mask_k, float('-inf')).type_as(attn_weights)
        attn_weights_float = utils.softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
    attn_weights = attn_weights_float.type_as(attn_weights)
    attn_probs = F.dropout(attn_weights_float.type_as(attn_weights),
                           p=self.dropout,
                           training=self.training)
    if not bmm_fp16_support:
        attn_probs = attn_probs.float()  # bsz * self.num_heads, tgt_len, src_len
    attn = torch.bmm(attn_probs, v)
    if not bmm_fp16_support:
        attn = attn.type_as(query)
    assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
    if self.onnx_trace and attn.size(1) == 1:
        # when ONNX tracing a single decoder step (sequence length == 1)
        # the transpose is a no-op copy before view, thus unnecessary
        attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
    else:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn = self.out_proj(attn)

    if need_weights:
        attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len,
                                               src_len).transpose(1, 0)
        if not need_head_weights:
            # average attention weights over heads
            attn_weights = attn_weights.mean(dim=0)
    else:
        attn_weights = None

    return attn, attn_weights