def _forward_expanded(self, x, incremental_state):
    '''Turn the convolution filters into band matrices and do matrix multiplication.
    This is faster when the sequence is short, but less memory efficient.
    This is not used in the decoder during inference.
    '''
    T, B, C = x.size()
    K, H = self.kernel_size, self.num_heads
    R = C // H
    assert R * H == C == self.input_size

    if self.weight_linear:
        if self.sample_lc_sp:
            weight = self.weight_linear(
                torch.mean(torch.mean(x, dim=0), dim=1).unsqueeze(1)).view(B, H, K)  # B, H, K
        else:
            weight = self.weight_linear(torch.mean(x, dim=0)).view(B, H, K)  # B, H, K
    else:
        weight = self.weight.view(H, K)

    if self.weight_softmax:
        if self.weight_linear:
            weight = utils.softmax(
                weight, dim=2, onnx_trace=self.onnx_trace).type_as(weight)
        else:
            weight = utils.softmax(
                weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight)

    if self.weight_linear:
        weight = weight.unsqueeze(0).expand(T, B, H, K).reshape(T * B, H, K).contiguous()
    else:
        weight = weight.view(1, H, K).expand(T * B, H, K).contiguous()

    weight = weight.view(T, B * H, K).transpose(0, 1)  # (B*H, T, K)

    x = x.view(T, B * H, R).transpose(0, 1)  # (B*H, T, R)
    P = self.padding_l
    if K > T and P == K - 1:
        weight = weight.narrow(2, K - T, T)
        K, P = T, T - 1

    # turn the convolution filters into band matrices
    weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False)
    weight_expanded.as_strided((B * H, T, K), (T * (T + K - 1), T + K, 1)).copy_(weight)
    weight_expanded = weight_expanded.narrow(2, P, T)  # (B*H, T, T)
    weight_expanded = F.dropout(weight_expanded, self.weight_dropout, training=self.training)

    if bmm_fp16_support:
        output = torch.bmm(weight_expanded, x)  # (B*H, T, R)
    else:
        output = torch.bmm(weight_expanded.float(), x.float()).type_as(weight)
    output = output.transpose(0, 1).contiguous().view(T, B, C)
    return output
def get_normalized_probs(self, ctc_logits, logits, log_probs):
    """Get normalized probabilities (or log probs) from a net's output."""
    if log_probs:
        ctc_res = utils.log_softmax(ctc_logits.float(), dim=-1)
        res = utils.log_softmax(logits.float(), dim=-1)
    else:
        ctc_res = utils.softmax(ctc_logits.float(), dim=-1)
        res = utils.softmax(logits.float(), dim=-1)
    ctc_res.batch_first = True
    res.batch_first = True
    return ctc_res, res
def get_normalized_probs(self, net_output, log_probs):
    """Get normalized probabilities (or log probs) from a net's output."""
    logits_ctc = net_output["logits_ctc"]
    logits = net_output["logits"]
    if log_probs:
        ctc_res = utils.log_softmax(logits_ctc.float(), dim=-1)
        res = utils.log_softmax(logits.float(), dim=-1)
    else:
        ctc_res = utils.softmax(logits_ctc.float(), dim=-1)
        res = utils.softmax(logits.float(), dim=-1)
    return ctc_res, res
def cross_attentive_loss(self, teacher_states, student_states, teacher_masking,
                         student_masking, eps=1e-6):
    x = teacher_states.transpose(0, 1)  # from T x B x D to B x T x D
    y = student_states.transpose(0, 1)
    if self.cross_attentive_loss_with_norm:
        x = x / (x.norm(dim=2, keepdim=True) + eps)
        y = y / (y.norm(dim=2, keepdim=True) + eps)
    dim = x.size(-1)
    # lengths: batch x seq_len
    sim_scores_xy = torch.bmm(x, y.transpose(1, 2))  # batch x len_x x len_y
    if y.dtype == torch.float16:
        sim_scores_xy = sim_scores_xy.float()
        y = y.float()
        x = x.float()
    if teacher_masking != []:
        assert len(teacher_masking) == 1
        sim_scores_xy = sim_scores_xy.masked_fill(
            teacher_masking[0].unsqueeze(-1), float("-inf"))
    if student_masking != []:
        sim_scores_xy = sim_scores_xy.masked_fill(
            student_masking[0].unsqueeze(1), float("-inf"))

    # do masking
    y_weights = utils.softmax(sim_scores_xy, dim=-1)
    if teacher_masking != []:
        y_weights = y_weights.masked_fill(teacher_masking[0].unsqueeze(-1), 0)
    x_reconstruct_from_y = torch.bmm(y_weights, y)

    sim_scores_xx = torch.bmm(x, x.transpose(1, 2))  # batch x len_x x len_x
    x_weights = utils.softmax(sim_scores_xx, dim=-1)
    if teacher_masking != []:
        x_weights = x_weights.masked_fill(teacher_masking[0].unsqueeze(-1), 0)

    # no gradient for teacher state
    x_reconstruct_from_x = torch.bmm(x_weights, x).detach()
    cost = (x_reconstruct_from_x - x_reconstruct_from_y).norm(dim=2)
    if teacher_masking != []:
        cost = cost.masked_fill(teacher_masking[0], 0)

    if not self.cross_attentive_loss_with_norm:
        cost = cost / dim
    return cost
def get_normalized_probs(self, net_output, log_probs, sample, adaptive_softmax=True):
    """Get normalized probabilities (or log probs) from a net's output."""
    if adaptive_softmax:
        if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
            return out.exp_() if not log_probs else out

    # net_output may be a list (take the logits) or already a bare logits tensor
    logits = net_output[0] if isinstance(net_output, list) else net_output
    if log_probs:
        return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
    else:
        return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
def get_normalized_probs(self, net_output, log_probs, sample):
    """Get normalized probabilities (or log probs) from a net's output."""
    if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
        if sample is not None:
            assert 'target' in sample
            target = sample['target']
        else:
            target = None
        out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
        return out.exp_() if not log_probs else out
    '''
    logits_list = net_output[0]
    if log_probs:
        return [utils.log_softmax(
            logits, dim=-1, onnx_trace=self.onnx_trace) for logits in logits_list][0]
    else:
        return [utils.softmax(
            logits, dim=-1, onnx_trace=self.onnx_trace) for logits in logits_list][0]
    '''
    logits = net_output[0]
    if log_probs:
        return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
    else:
        return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
def get_normalized_probs_with_temperature(self, net_output, log_probs, sample=None,
                                          temperature=1.):
    """Get normalized probabilities (or log probs) from a net's output."""
    if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
        if sample is not None:
            assert 'target' in sample
            target = sample['target']
        else:
            target = None
        out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
        return out.exp_() if not log_probs else out

    logits = net_output[0] / temperature
    if log_probs:
        return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
    else:
        return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
def _forward_unfolded(self, x, incremental_state):
    '''The conventional implementation of convolutions.
    Unfolding the input by having a window shifting to the right.'''
    T, B, C = x.size()
    K, H = self.kernel_size, self.num_heads
    R = C // H
    assert R * H == C == self.input_size

    weight = self.weight.view(H, K)
    if incremental_state is not None:
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is None:
            input_buffer = x.new()
        x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
        if self.kernel_size > 1:
            self._set_input_buffer(
                incremental_state, x_unfold[:, :, :, -self.kernel_size + 1:])
        x_unfold = x_unfold.view(T * B * H, R, -1)
    else:
        # unfold the input: T x B x C --> T' x B x C x K
        x_unfold = unfold1d(x, self.kernel_size, self.padding_l, 0)
        x_unfold = x_unfold.view(T * B * H, R, K)

    if self.weight_softmax:
        weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight)

    if incremental_state is not None:
        weight = weight[:, -x_unfold.size(2):]
        K = weight.size(1)

    weight = weight.view(1, H, K).expand(T * B, H, K).contiguous().view(T * B * H, K, 1)

    weight = self.weight_dropout_module(weight)
    output = torch.bmm(x_unfold, weight)  # T*B*H x R x 1
    output = output.view(T, B, C)
    return output
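# --- Illustrative sketch, not part of the module above -----------------------
# `unfold1d` comes from fairseq; the reference implementation below is an
# assumption about its behaviour written for readability, not the library code:
# it turns a T x B x C input into T x B x C x K windows, where window t covers
# input positions [t - padding_l, t - padding_l + K) and out-of-range positions
# are filled with `pad_value`. `_forward_unfolded` then multiplies each window
# by the (optionally softmax-normalized) kernel of its head.
import torch
import torch.nn.functional as F


def unfold1d_reference(x, kernel_size, padding_l, pad_value=0):
    T, B, C = x.size()
    padding_r = kernel_size - 1 - padding_l
    # pad the time dimension, then stack K shifted views along a new last axis
    x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r), value=pad_value)  # (T+K-1, B, C)
    return torch.stack([x[k:k + T] for k in range(kernel_size)], dim=3)  # (T, B, C, K)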
def get_normalized_probs(self, net_output, log_probs):
    """Get normalized probabilities (or log probs) from a net's output."""
    logits = net_output["encoder_out"]
    if log_probs:
        return utils.log_softmax(logits.float(), dim=-1)
    else:
        return utils.softmax(logits.float(), dim=-1)
def get_normalized_probs(
    self,
    net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
    log_probs: bool,
    sample: Optional[Dict[str, Tensor]] = None,
):
    """Get normalized probabilities (or log probs) from a net's output."""
    if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
        if sample is not None:
            assert "target" in sample
            target = sample["target"]
        else:
            target = None
        out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
        return out.exp_() if not log_probs else out

    logits = net_output[0]
    if log_probs:
        # print('Fairseq Decoder: net_output size: {}'.format(net_output.size()))
        if use_ort_backend:
            return utils.log_softmax(net_output, dim=-1, onnx_trace=self.onnx_trace)
        return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
    else:
        return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
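# --- Hedged usage note --------------------------------------------------------
# In their default (non-adaptive-softmax) path, the get_normalized_probs
# variants in this file reduce to a softmax or log_softmax over the last
# dimension of the logits. The self-contained check below uses toy tensors and
# plain torch (standing in for fairseq's utils wrappers) to illustrate the
# invariant that the log_probs=True and log_probs=False outputs differ only by
# an exp/log.
import torch

logits = torch.randn(2, 5, 11)                       # (batch, tgt_len, vocab)
lprobs = torch.log_softmax(logits.float(), dim=-1)   # log_probs=True path
probs = torch.softmax(logits.float(), dim=-1)        # log_probs=False path
assert torch.allclose(lprobs.exp(), probs, atol=1e-6)
assert torch.allclose(probs.sum(dim=-1), torch.ones(2, 5))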
def get_normalized_probs(
    self,
    net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
    log_probs: bool,
    sample: Optional[Dict[str, Tensor]] = None,
):
    """Get normalized probabilities (or log probs) from a net's output."""
    if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
        if sample is not None:
            assert "source" in sample
            source = sample["source"]
        else:
            source = None
        out = self.adaptive_softmax.get_log_prob(net_output[0], target=source)
        return out.exp_() if not log_probs else out

    logits = net_output[0]
    if log_probs:
        return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
    else:
        return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
def forward(self, query, value, key_padding_mask=None, state=None):
    # projected_query: 1 x bsz x embed_dim
    projected_query = self.query_proj(query).unsqueeze(0)
    key = self.value_proj(value)  # len x bsz x embed_dim
    if self.normalize:
        # normed_v = g * v / ||v||
        normed_v = self.g * self.v / torch.norm(self.v)
        attn_scores = (
            normed_v * torch.tanh(projected_query + key + self.b)
        ).sum(dim=2)  # len x bsz
    else:
        attn_scores = (self.v * torch.tanh(projected_query + key)).sum(dim=2)  # len x bsz

    if key_padding_mask is not None:
        attn_scores = attn_scores.float().masked_fill_(
            key_padding_mask, float('-inf'),
        ).type_as(attn_scores)  # FP16 support: cast to float and back

    attn_scores = utils.softmax(
        attn_scores, dim=0, onnx_trace=self.onnx_trace
    ).type_as(attn_scores)  # len x bsz

    # sum weighted value. context: bsz x value_dim
    context = (attn_scores.unsqueeze(2) * value).sum(dim=0)
    next_state = attn_scores

    return context, attn_scores, next_state
def get_normalized_probs(self, net_output, log_probs, sample):
    """Get normalized probabilities (or log probs) from a net's output."""
    logits = net_output
    if log_probs:
        return utils.log_softmax(logits, dim=-1)
    else:
        return utils.softmax(logits, dim=-1)
def generate(self, models, sample, **unused):
    """Generate a batch of inferences."""
    model = models[0]
    # encoder_output = model.encoder(tbc=False, **sample["net_input"])
    # alphas = CIFFcModel.get_alphas(encoder_output)
    # decode_length = torch.round(alphas.sum(-1)).int()
    # _alphas, num_output = model.resize(alphas, decode_length, noise=0.0)
    #
    # padding_mask = ~utils.sequence_mask(decode_length).bool()
    # cif_outputs = model.cif(encoder_output['encoder_out'][:, :, :-1], _alphas)
    # hidden = model.proj(cif_outputs)
    # logits_ac = model.to_vocab_ac(hidden)
    #
    # infer_threash = self.infer_threshold if self.infer_threshold else model.args.infer_threash
    # for i in range(1):
    #     logits, gold_embedding, pred_mask, token_mask = model.bert_forward(
    #         hidden, logits_ac, padding_mask, None, 0.0,
    #         threash=infer_threash)
    # logits = self.args.lambda_am * logits_ac + model.args.lambda_lm * logits
    # probs = utils.softmax(logits.float(), dim=-1)

    net_output = model(**sample["net_input"])
    logits = net_output['logits']
    probs = utils.softmax(logits.float(), dim=-1)
    decode_length = net_output['len_logits']

    res = []
    for distribution, length in zip(probs, decode_length):
        result = distribution.argmax(-1)
        score = 0.0
        res.append([{'tokens': result[:length], "score": score}])

    return res
def _forward_expanded(self, x, incremental_state):
    """Turn the convolution filters into band matrices and do matrix multiplication.
    This is faster when the sequence is short, but less memory efficient.
    This is not used in the decoder during inference.
    """
    T, B, C = x.size()
    K, H = self.kernel_size, self.num_heads
    R = C // H
    assert R * H == C == self.input_size

    weight = self.weight.view(H, K)
    if self.weight_softmax:
        weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(
            weight
        )
    weight = weight.view(1, H, K).expand(T * B, H, K).contiguous()
    weight = weight.view(T, B * H, K).transpose(0, 1)

    x = x.view(T, B * H, R).transpose(0, 1)
    P = self.padding_l
    if K > T and P == K - 1:
        weight = weight.narrow(2, K - T, T)
        K, P = T, T - 1

    # turn the convolution filters into band matrices
    weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False)
    weight_expanded.as_strided((B * H, T, K), (T * (T + K - 1), T + K, 1)).copy_(
        weight
    )
    weight_expanded = weight_expanded.narrow(2, P, T)
    weight_expanded = self.weight_dropout_module(weight_expanded)

    output = torch.bmm(weight_expanded, x)
    output = output.transpose(0, 1).contiguous().view(T, B, C)
    return output
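# --- Illustrative sketch, not part of the module above -----------------------
# Toy demonstration of the band-matrix trick used by _forward_expanded: writing
# the per-head kernel into a strided view of a (T, T + K - 1) zero matrix and
# then narrowing to columns [P, P + T) yields a T x T band matrix whose row t
# holds the kernel aligned so that bmm reproduces the unfolded convolution.
# The sizes below are arbitrary (B = H = 1, T = 5, K = 3, causal padding P = K - 1).
import torch

B, H, T, K = 1, 1, 5, 3
P = K - 1  # causal (left) padding
kernel = torch.tensor([1., 2., 3.])                          # one head's filter
weight = kernel.view(1, H, K).expand(T * B, H, K).contiguous()
weight = weight.view(T, B * H, K).transpose(0, 1)            # (B*H, T, K)

expanded = weight.new_zeros(B * H, T, T + K - 1)
expanded.as_strided((B * H, T, K), (T * (T + K - 1), T + K, 1)).copy_(weight)
band = expanded.narrow(2, P, T)                              # (B*H, T, T)
print(band[0])
# tensor([[3., 0., 0., 0., 0.],
#         [2., 3., 0., 0., 0.],
#         [1., 2., 3., 0., 0.],
#         [0., 1., 2., 3., 0.],
#         [0., 0., 1., 2., 3.]])
# torch.bmm(band, x) then yields the same causal windowed sum that
# _forward_unfolded computes with this kernel.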
def bert_forward(self, hidden, logits_ac, padding_mask, input_ids=None,
                 gold_rate=0.0, threash=0.8):
    """
    """
    device = hidden.device
    if self.training:
        token_mask = input_ids.ne(self.tgt_dict.cls()) * \
                     input_ids.ne(self.tgt_dict.sep()) * \
                     input_ids.ne(self.tgt_dict.pad())
        gold_embedding = self.bert.embeddings.word_embeddings(input_ids)
        pred_mask = (torch.rand(input_ids.size(), device=device) > gold_rate) * token_mask
    else:  # infer
        token_mask = F.pad(~padding_mask, [1, 1, 0, 0], value=0)
        probs = F.pad(utils.softmax(logits_ac.float(), dim=-1), [0, 0, 1, 1, 0, 0], value=0)
        confident, preds = probs.max(-1)
        preds_ids = pred2bert_input(preds, token_mask)
        # preds = torch.where(token_mask, preds, input_ids)
        gold_embedding = self.bert.embeddings.word_embeddings(preds_ids)
        pred_mask = (confident < threash) * token_mask

    hidden_mix = torch.where(pred_mask[:, :, None].repeat(1, 1, hidden.size(-1)),
                             F.pad(hidden, [0, 0, 1, 1, 0, 0], value=0),
                             gold_embedding)
    attention_mask = padding2attention_mask(padding_mask)
    embeddings = self.bert.embeddings(inputs_embeds=hidden_mix)
    encoder_outputs = self.bert.encoder(
        embeddings, attention_mask=attention_mask[:, None, None, :])
    logits = self.to_vocab(encoder_outputs[0])
    logits = logits[:, 1:-1, :]

    return logits, gold_embedding, pred_mask, token_mask
def attn_fn(attn_weights, is_query=False):
    if attn_mask is not None:
        attn_weights += attn_mask

    if key_padding_mask is not None:
        attn_weights = attn_weights.view(bsz, self.num_heads, src_len, src_len)
        attn_weights = attn_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_weights = attn_weights.view(bsz * self.num_heads, src_len, src_len)

    if is_query is True:
        query_mask = torch.eye(attn_weights.size(-1)).to(attn_weights) * -1e9
        query_mask[0][0] = 0.0
        attn_weights = attn_weights + query_mask

    attn_weights = utils.softmax(
        attn_weights, dim=-1,
    ).type_as(attn_weights)
    attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
    return attn_weights
def get_normalized_probs(self, net_output, log_probs, sample, gs_tau=0.5, gs_hard=False):
    """Get normalized probabilities (or log probs) from a net's output."""
    if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
        if sample is not None:
            assert 'target' in sample
            target = sample['target']
        else:
            target = None
        out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
        return out.exp_() if not log_probs else out

    logits = net_output[0][0]
    orders = net_output[0][1]
    if log_probs:
        return (utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace),
                self.gumbel_softmax(orders, gs_tau=gs_tau, gs_hard=gs_hard, dim=-1))
    else:
        return (utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace),
                self.gumbel_softmax(orders, gs_tau=gs_tau, gs_hard=gs_hard, dim=-1))
def get_normalized_probs(self, net_output, log_probs, retrun_ctc=False):
    """Get normalized probabilities (or log probs) from a net's output."""
    logits_ctc = net_output["logits_ctc"]
    logits = net_output["logits"]
    if log_probs:
        res_ctc = utils.log_softmax(logits_ctc.float(), dim=-1)
        res = utils.log_softmax(logits.float(), dim=-1)
    else:
        res_ctc = utils.softmax(logits_ctc.float(), dim=-1)
        res = utils.softmax(logits.float(), dim=-1)
    res_ctc.batch_first = True
    res.batch_first = True

    if retrun_ctc:
        return res_ctc, res
    else:
        return res
def get_normalized_probs_w2v(self, net_output, log_probs):
    """Get normalized probabilities (or log probs) from a net's output."""
    print(net_output.keys())
    logits = net_output["wav2vec_logits"]
    if log_probs:
        return utils.log_softmax(logits.float(), dim=-1)
    else:
        return utils.softmax(logits.float(), dim=-1)
def get_normalized_probs(self, net_output, log_probs, sample):
    """Get normalized probabilities (or log probs) from a net's output."""
    if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
        if sample is not None:
            assert 'target' in sample
            target = sample['target']
        else:
            target = None
        out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
        return out.exp_() if not log_probs else out

    logits = net_output[0]
    is_copy = 'p_copy' in net_output[1].keys() and net_output[1]['p_copy'] is not None
    # print(net_output[1]['attn'])
    if is_copy and False:
        p_copy = net_output[1]['p_copy']
        if 'net_input' in sample.keys():
            enc_seq_ids = sample['net_input']['src_tokens']
        else:
            # for decode step
            enc_seq_ids = sample['src_tokens']
        enc_seq_ids = enc_seq_ids.unsqueeze(1).repeat(
            1, net_output[1]['copy_attn'].size(1), 1)
        generate_prob = utils.softmax(
            logits, dim=-1, onnx_trace=self.onnx_trace) * (1 - p_copy)
        copy_prob = net_output[1]['copy_attn'] * p_copy
        final = generate_prob.scatter_add(2, enc_seq_ids, copy_prob)
        if log_probs:
            return torch.log(final + 1e-15)
        else:
            return final
    else:
        if log_probs:
            return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
def forward(self, x, need_attention_weights=False):
    # Attention scores:
    alpha = self.w2(self.w1(x))  # B, Tt, Ts, 1
    alpha = utils.softmax(alpha, dim=2).type_as(alpha)
    x = x.permute(0, 1, 3, 2)
    x = torch.matmul(x, alpha).squeeze(-1)
    if need_attention_weights:
        return x, alpha.squeeze(-1)
    return x, None
def get_normalized_probs_cif(self, net_output, log_probs):
    """Get normalized probabilities (or log probs) from a net's output."""
    logits = self.get_logits_cif(net_output)
    if log_probs:
        return utils.log_softmax(logits.float(), dim=-1)
    else:
        return utils.softmax(logits.float(), dim=-1)
def one_step(self, x, need_attention_weights=False):
    x = x[:, -1:]  # B, 1, Ts, C
    alpha = self.w2(self.w1(x))  # B, 1, Ts, 1
    alpha = utils.softmax(alpha, dim=2)
    x = x.permute(0, 1, 3, 2)
    x = torch.matmul(x, alpha).squeeze(-1)
    if need_attention_weights:
        return x, alpha.squeeze(-1)
    return x, None
def forward(self, x, need_attention_weights=False):
    # Attention scores:
    B, Tt, Ts, C = x.size()
    alpha = self.w2(self.w1(x))  # B, Tt, Ts, 1
    # for every (t, j), only allow attending to the first j source positions
    mask = torch.triu(utils.fill_with_neg_inf(x.new(Ts, Ts)), 1).type_as(alpha)
    alpha = alpha.permute(0, 1, 3, 2) + mask.unsqueeze(0).unsqueeze(0)  # B, Tt, Ts, Ts
    alpha = utils.softmax(alpha, dim=-1)
    x = torch.matmul(alpha, x)
    return x, None
def get_normalized_probs(self, net_output, log_probs):
    """Get normalized probabilities (or log probs) from a net's output."""
    logits = net_output[0]
    if log_probs:
        res = utils.log_softmax(logits.float(), dim=-1)
    else:
        res = utils.softmax(logits.float(), dim=-1)
    res.batch_first = True
    return res
def forward(self, x, need_attention_weights=False):
    # Attention scores:
    B, Tt, Ts, C = x.size()
    alpha = self.w2(self.w1(x))  # B, Tt, Ts, 1
    mask = torch.triu(utils.fill_with_neg_inf(x.new(Tt, Ts)), self.waitk)
    alpha = utils.softmax(alpha + mask.unsqueeze(0).unsqueeze(-1), dim=2).type_as(alpha)
    x = x.permute(0, 1, 3, 2)
    x = torch.matmul(x, alpha).squeeze(-1)
    if need_attention_weights:
        return x, alpha.squeeze(-1)
    return x, None
def gumbel_softmax(self, logits, gs_tau=0.5, gs_hard=False, dim=-1):
    if not gs_hard:
        prob = utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
        prob_clamp = torch.clamp(
            prob, self.clamp_value,
            1. - (self.decoder_max_order - 1) * self.clamp_value)
        logprob = torch.log(prob_clamp if self.gs_clamp else prob)
        gs = F.gumbel_softmax(
            logprob,
            tau=gs_tau,
            hard=False,
        )
    else:
        prob = utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
        prob_clamp = torch.clamp(
            prob, self.clamp_value,
            1. - (self.decoder_max_order - 1) * self.clamp_value)
        max_idx = torch.argmax(logits, -1, keepdim=True)
        one_hot = logits.new_zeros(logits.size())
        gs = one_hot.scatter(-1, max_idx, 1)

    return gs, prob, prob_clamp
def get_normalized_probs(self, net_output, log_probs, sample):
    """Get normalized probabilities (or log probs) from a net's output."""
    # print('enter normalized.')
    if 'net_input' in sample.keys():
        enc_seq_ids = sample['net_input']['src_tokens']
    else:
        enc_seq_ids = sample['src_tokens']

    # wvocab_size = net_output[0].size(2)
    # batch_size = enc_seq_ids.size(0)
    # seq_len = enc_seq_ids.size(1)
    # one_hot = torch.zeros(batch_size, seq_len, wvocab_size).cuda().scatter_(dim=2, index=enc_seq_ids.unsqueeze(-1), value=1)
    #
    # copy_probs = torch.matmul(net_output[1]['attn'], one_hot)
    # final_dist = vocab_dist.scatter_add(1, encoder_batch_extend_vocab, attn_dist)

    if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
        if sample is not None:
            assert 'target' in sample
            target = sample['target']
        else:
            target = None
        out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
        return out.exp_() if not log_probs else out

    logits = net_output[0]
    # combine generation and copy distributions in probability space,
    # then optionally take the log
    generate = utils.softmax(
        logits, dim=-1, onnx_trace=self.onnx_trace) * net_output[1]['copy_or_generate']
    copy = net_output[1]['attn'] * (1 - net_output[1]['copy_or_generate'])
    enc_seq_ids = enc_seq_ids.unsqueeze(1).repeat(
        1, net_output[1]['attn'].size(1), 1)
    final = generate.scatter_add(2, enc_seq_ids, copy)
    if log_probs:
        final = torch.log(final + 1e-15)
    return final
def get_normalized_probs(self, net_output, log_probs, sample):
    """Get normalized probabilities (or log probs) from a net's output."""
    if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
        if sample is not None:
            assert 'target' in sample
            target = sample['target']
        else:
            target = None
        out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
        return out.exp_() if not log_probs else out

    logits = net_output[0]
    copy_scores = net_output[1]["copy_scores"]
    p_copy = net_output[1]["p_copy"].float()
    if log_probs:
        return torch.log((1 - p_copy) * utils.softmax(logits, dim=-1)
                         + p_copy * copy_scores.float())
    else:
        return (1 - p_copy) * utils.softmax(logits, dim=-1) + p_copy * copy_scores.float()
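# --- Illustrative sketch, not part of the module above -----------------------
# Toy check of the copy-mechanism mixture used above (plain torch, random
# tensors standing in for model outputs): as long as `copy_scores` is itself a
# distribution over the vocabulary, the gated combination
# (1 - p_copy) * softmax(logits) + p_copy * copy_scores remains a valid
# probability distribution, so taking its log gives the log_probs output.
import torch

logits = torch.randn(2, 3, 7)                                # (batch, tgt_len, vocab)
copy_scores = torch.softmax(torch.randn(2, 3, 7), dim=-1)    # normalized copy dist.
p_copy = torch.rand(2, 3, 1)                                 # copy gate in [0, 1]

final = (1 - p_copy) * torch.softmax(logits, dim=-1) + p_copy * copy_scores
assert torch.allclose(final.sum(dim=-1), torch.ones(2, 3), atol=1e-6)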