def get_normalized_probs(
    self,
    net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
    log_probs: bool,
    sample: Optional[Dict[str, Tensor]] = None,
):
    """Get normalized probabilities (or log probs) from a net's output."""
    if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
        if sample is not None:
            assert "target" in sample
            target = sample["target"]
        else:
            target = None
        out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
        return out.exp_() if not log_probs else out

    logits = net_output[0]
    if log_probs:
        if use_ort_backend:
            # ORT path: net_output is expected to be the raw logits tensor here,
            # not the usual (logits, extra) tuple.
            return utils.log_softmax(net_output, dim=-1, onnx_trace=self.onnx_trace)
        return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
    else:
        return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
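# Hedged usage sketch (added for illustration, not part of the original code):
# how a scoring routine typically calls get_normalized_probs on a fairseq-style
# model. `model` and `sample` are hypothetical placeholders.
def score_targets_example(model, sample):
    net_output = model(**sample["net_input"])  # (logits, extra)
    lprobs = model.get_normalized_probs(net_output, log_probs=True, sample=sample)  # B x T x V
    target = sample["target"].unsqueeze(-1)
    # log-probability of each reference token
    return lprobs.gather(dim=-1, index=target).squeeze(-1)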
def get_normalized_probs(self, ctc_logits, logits, log_probs):
    """Get normalized probabilities (or log probs) from a net's output."""
    if log_probs:
        ctc_res = utils.log_softmax(ctc_logits.float(), dim=-1)
        res = utils.log_softmax(logits.float(), dim=-1)
    else:
        ctc_res = utils.softmax(ctc_logits.float(), dim=-1)
        res = utils.softmax(logits.float(), dim=-1)
    ctc_res.batch_first = True
    res.batch_first = True
    return ctc_res, res
def get_normalized_probs(self, net_output, log_probs):
    """Get normalized probabilities (or log probs) from a net's output."""
    logits_ctc = net_output["logits_ctc"]
    logits = net_output["logits"]
    if log_probs:
        ctc_res = utils.log_softmax(logits_ctc.float(), dim=-1)
        res = utils.log_softmax(logits.float(), dim=-1)
    else:
        ctc_res = utils.softmax(logits_ctc.float(), dim=-1)
        res = utils.softmax(logits.float(), dim=-1)
    return ctc_res, res
def get_logits(self, net_output): logits = net_output["x"] lprob = utils.log_softmax(logits.float(), dim=-1) lprob = lprob.transpose(0, 1) lprob.batch_first = False return lprob
def get_normalized_probs_with_temperature(self, net_output, log_probs, sample=None, temperature=1.):
    """Get normalized probabilities (or log probs) from a net's output."""
    if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
        if sample is not None:
            assert 'target' in sample
            target = sample['target']
        else:
            target = None
        out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
        return out.exp_() if not log_probs else out

    logits = net_output[0] / temperature
    if log_probs:
        return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
    else:
        return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
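# Hedged usage sketch (illustration only): temperature sampling with the helper
# above. `decoder`, `net_output`, and the temperature value are hypothetical.
import torch

def sample_with_temperature_example(decoder, net_output, temperature=0.7):
    probs = decoder.get_normalized_probs_with_temperature(
        net_output, log_probs=False, temperature=temperature)  # B x T x V
    # sample one token id for the last position of each sequence
    return torch.multinomial(probs[:, -1, :], num_samples=1)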
def compute_mlm_loss(self, enc_output, target, data_type=None, reduce=True):
    lprobs = utils.log_softmax(enc_output, dim=-1, onnx_trace=False)
    p_lprobs = lprobs.clone()
    if data_type is not None:
        data_type = data_type.view(-1, 1).repeat(1, lprobs.size()[1])
        data_type = data_type.view(-1, 1)
        # de is the monolingual side, so for en->de the source flag is flipped (1 - data_type)
        data_type = 1 - data_type
    lprobs = lprobs.view(-1, lprobs.size(-1))
    predict_sentence = torch.argmax(lprobs, dim=-1)
    predict_sentence = predict_sentence.view(target.size())
    target = target.view(-1, 1)
    loss, nll_loss = label_smoothed_nll_loss(
        lprobs, target, self.eps, ignore_index=self.padding_idx, data_type=data_type)
    return loss, predict_sentence, p_lprobs
def get_normalized_probs(self, net_output, log_probs, sample, adaptive_softmax=True):
    """Get normalized probabilities (or log probs) from a net's output."""
    if adaptive_softmax:
        if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
            return out.exp_() if not log_probs else out

    # net_output may be either the usual [logits, extra] list or the bare logits tensor
    logits = net_output[0] if isinstance(net_output, list) else net_output
    if log_probs:
        return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
    else:
        return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
def get_normalized_probs(
    self,
    net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
    log_probs: bool,
    sample: Optional[Dict[str, Tensor]] = None,
):
    """Get normalized probabilities (or log probs) from a net's output."""
    if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
        if sample is not None:
            assert "source" in sample
            source = sample["source"]
        else:
            source = None
        out = self.adaptive_softmax.get_log_prob(net_output[0], target=source)
        return out.exp_() if not log_probs else out

    logits = net_output[0]
    if log_probs:
        return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
    else:
        return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
def get_normalized_probs(self, net_output, log_probs, sample):
    """Get normalized probabilities (or log probs) from a net's output."""
    logits = net_output
    if log_probs:
        return utils.log_softmax(logits, dim=-1)
    else:
        return utils.softmax(logits, dim=-1)
def compute_loss(self, model, net_output, sample, reduce=True):
    lprobs = model.get_normalized_probs(net_output, log_probs=True)
    lprobs = lprobs.view(-1, lprobs.size(-1))
    sample['padding_idx'] = self.padding_idx
    target = model.get_targets(sample, net_output).view(-1, 1)
    non_pad_mask = target.ne(self.padding_idx)

    # compute length prediction loss
    length_lprobs = net_output[1]['predicted_lengths']
    length_target = sample['net_input']['prev_output_tokens'].ne(
        self.padding_idx).sum(-1).unsqueeze(-1)
    length_loss = -length_lprobs.gather(dim=-1, index=length_target)

    src_lprobs = utils.log_softmax(net_output[1]['encoder_out'], dim=-1)
    src_lprobs = src_lprobs.view(-1, src_lprobs.size(-1))
    src_target = sample['src_target'].view(-1, 1)
    src_non_pad_mask = src_target.ne(self.padding_idx)
    src_nll_loss = -src_lprobs.gather(dim=-1, index=src_target)[src_non_pad_mask]

    nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
        length_loss = length_loss.sum()
        src_nll_loss = src_nll_loss.sum()
    eps_i = self.eps / lprobs.size(-1)
    loss = (
        (1. - self.eps) * nll_loss
        + eps_i * smooth_loss
        + 0.1 * length_loss
        + 0.01 * src_nll_loss
    )
    return loss, nll_loss, length_loss, src_nll_loss
def get_normalized_probs(self, net_output, log_probs, sample):
    """Get normalized probabilities (or log probs) from a net's output."""
    if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
        if sample is not None:
            assert 'target' in sample
            target = sample['target']
        else:
            target = None
        out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
        return out.exp_() if not log_probs else out
    '''
    logits_list = net_output[0]
    if log_probs:
        return [utils.log_softmax(
            logits, dim=-1, onnx_trace=self.onnx_trace) for logits in logits_list][0]
    else:
        return [utils.softmax(
            logits, dim=-1, onnx_trace=self.onnx_trace) for logits in logits_list][0]
    '''
    logits = net_output[0]
    if log_probs:
        return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
    else:
        return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
def get_normalized_probs(self, net_output, log_probs, sample, gs_tau=0.5, gs_hard=False):
    """Get normalized probabilities (or log probs) from a net's output."""
    if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
        if sample is not None:
            assert 'target' in sample
            target = sample['target']
        else:
            target = None
        out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
        return out.exp_() if not log_probs else out

    logits = net_output[0][0]
    orders = net_output[0][1]
    if log_probs:
        return (utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace),
                self.gumbel_softmax(orders, gs_tau=gs_tau, gs_hard=gs_hard, dim=-1))
    else:
        return (utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace),
                self.gumbel_softmax(orders, gs_tau=gs_tau, gs_hard=gs_hard, dim=-1))
def get_normalized_probs(self, net_output, log_probs):
    """Get normalized probabilities (or log probs) from a net's output."""
    logits = net_output["encoder_out"]
    if log_probs:
        return utils.log_softmax(logits.float(), dim=-1)
    else:
        return utils.softmax(logits.float(), dim=-1)
def compute_loss(self, model, net_output, sample, reduce=True):
    # get target and generated text
    target = model.get_targets(sample, net_output).view(-1, 1)

    ## semantic sim_loss
    output_tokens = net_output[0]
    sentence_tok = torch.argmax(utils.log_softmax(output_tokens, dim=-1), -1)  # maxpool
    sentence_txt = self.bpe.decode(
        self.task.target_dictionary.string(sentence_tok))
    ignore_index = self.padding_idx
    if ignore_index is not None:
        non_pad_mask = target.ne(ignore_index)
        target_ig = target[non_pad_mask]
        target_txt = self.bpe.decode(
            self.task.target_dictionary.string(target_ig))
    print("\n\n## sentence_txt: ", sentence_txt, "\n## target_txt: ", target_txt)

    lprobs = model.get_normalized_probs(net_output, log_probs=True)
    lprobs = lprobs.view(-1, lprobs.size(-1))
    target = model.get_targets(sample, net_output).view(-1, 1)
    loss, nll_loss = label_smoothed_nll_loss(
        lprobs,
        target,
        self.eps,
        ignore_index=self.padding_idx,
        reduce=reduce,
    )
    return loss, nll_loss
def get_normalized_probs(self, net_output, log_probs, return_ctc=False):
    """Get normalized probabilities (or log probs) from a net's output."""
    logits_ctc = net_output["logits_ctc"]
    logits = net_output["logits"]
    if log_probs:
        res_ctc = utils.log_softmax(logits_ctc.float(), dim=-1)
        res = utils.log_softmax(logits.float(), dim=-1)
    else:
        res_ctc = utils.softmax(logits_ctc.float(), dim=-1)
        res = utils.softmax(logits.float(), dim=-1)
    res_ctc.batch_first = True
    res.batch_first = True
    if return_ctc:
        return res_ctc, res
    else:
        return res
def get_normalized_probs_w2v(self, net_output, log_probs):
    """Get normalized probabilities (or log probs) from a net's output."""
    print(net_output.keys())
    logits = net_output["wav2vec_logits"]
    if log_probs:
        return utils.log_softmax(logits.float(), dim=-1)
    else:
        return utils.softmax(logits.float(), dim=-1)
def get_normalized_probs_cif(self, net_output, log_probs):
    """Get normalized probabilities (or log probs) from a net's output."""
    logits = self.get_logits_cif(net_output)
    if log_probs:
        return utils.log_softmax(logits.float(), dim=-1)
    else:
        return utils.softmax(logits.float(), dim=-1)
def get_knn_log_prob(self, queries, tgt, pad_idx):

    def dist_func(d, k, q, function=None):
        if not function:
            # Default behavior for L2 metric is to recompute distances.
            # Default behavior for IP metric is to return faiss distances.
            qsize = q.shape
            if self.metric_type == 'l2':
                start = time.time()
                knns_vecs = torch.from_numpy(self.keys[k]).cuda().view(
                    qsize[0], self.k, -1)
                if self.half:
                    knns_vecs = knns_vecs.half()
                query_vecs = q.view(qsize[0], 1, qsize[1]).repeat(1, self.k, 1)
                l2 = torch.sum((query_vecs - knns_vecs.detach()) ** 2, dim=2)
                return -1 * l2
            return d

        if function == 'dot':
            qsize = q.shape
            return (torch.from_numpy(self.keys[k]).cuda() *
                    q.view(qsize[0], 1, qsize[1])).sum(dim=-1)

        if function == 'do_not_recomp_l2':
            return -1 * d

        raise ValueError("Invalid knn similarity function!")

    # queries are TxBxC
    # reshape: (TxB)xC
    qshape = queries.shape
    queries = queries.view(-1, qshape[-1])
    tgt = tgt.contiguous().view(-1)
    dists, knns = self.get_knns(queries[tgt != pad_idx])
    # (T_reducedxB)xK
    dists = torch.from_numpy(dists).cuda()
    start = time.time()
    dists = dist_func(dists, knns, queries[tgt != pad_idx, :], function=self.sim_func)
    probs = utils.log_softmax(dists, dim=-1)

    index_mask = torch.eq(
        torch.from_numpy(self.vals[knns]).long().cuda().squeeze(-1),
        tgt[tgt != pad_idx].unsqueeze(-1)).float()
    index_mask[index_mask == 0] = -10000  # for stability
    index_mask[index_mask == 1] = 0

    # (T_reducedxB)
    yhat_knn_prob = torch.logsumexp(probs + index_mask, dim=-1).clone()
    full_yhat_knn_prob = torch.full([qshape[0] * qshape[1]], -10000).cuda()
    full_yhat_knn_prob[tgt != pad_idx] = yhat_knn_prob

    # TxBx1
    return full_yhat_knn_prob.view(qshape[0], qshape[1], 1)
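# Hedged sketch (an assumption, not in the original file): the standard kNN-LM
# interpolation a caller of get_knn_log_prob would apply, mixing the retrieved
# log-probability with the model's log-probability under a weight `lmbda`.
# `model_lprobs` and `knn_lprobs` are hypothetical tensors of identical shape.
import math
import torch

def interpolate_knn_example(model_lprobs, knn_lprobs, lmbda=0.25):
    # log((1 - lmbda) * p_model + lmbda * p_knn), computed stably in log space
    stacked = torch.stack(
        [model_lprobs + math.log(1.0 - lmbda), knn_lprobs + math.log(lmbda)], dim=0)
    return torch.logsumexp(stacked, dim=0)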
def forward_train(self, prev_output_tokens, encoder_out, target, **kwargs):
    print('Target tokens:', prev_output_tokens)
    # source embeddings
    src_emb = encoder_out['encoder_out']  # B, Ts, ds
    # target embeddings:
    positions = self.embed_positions(
        prev_output_tokens,
        incremental_state=None,
    ) if self.embed_positions is not None else None

    decoder_mask = prev_output_tokens.eq(self.padding_idx)
    if not decoder_mask.any():
        decoder_mask = None

    # Build the full grid
    tgt_emb = self.embed_scale * self.embed_tokens(prev_output_tokens)
    if positions is not None:
        tgt_emb += positions
    tgt_emb = self.embedding_dropout(tgt_emb)

    batch_size = src_emb.size(0)
    src_length = src_emb.size(1)
    tgt_length = tgt_emb.size(1)

    # build 2d "image" of embeddings
    src_emb = _expand(src_emb, 1, tgt_length)  # B, Tt, Ts, ds
    tgt_emb = _expand(tgt_emb, 2, src_length)  # B, Tt, Ts, dt
    x = torch.cat((src_emb, tgt_emb), dim=3)   # B, Tt, Ts, C=ds+dt
    x = self.input_dropout(x)
    if 'embed' in self.controller_input:
        observations = x

    # pass through dense convolutional layers
    encoder_mask = encoder_out['encoder_padding_mask']
    x = self.net(
        x,
        decoder_mask=decoder_mask,
        encoder_mask=encoder_mask,
        incremental_state=None,
    )  # B, Tt, Ts, C

    x, _ = self.aggregator(x)  # B, Tt, Ts, C
    x = self.projection(x) if self.projection is not None else x  # B, Tt, C
    if 'feat' in self.controller_input:
        if 'embed' in self.controller_input:
            observations = torch.cat((observations, x), dim=-1)
        else:
            observations = x

    # Predict
    x = self.prediction_dropout(x)
    x = self.prediction(x)  # B, Tt, Ts, V
    x = utils.log_softmax(x, dim=-1)
    x = x.view(-1, x.size(-1)).gather(
        dim=-1,
        index=target.unsqueeze(-1).expand(-1, -1, src_length).contiguous().view(-1, 1)
    ).view(batch_size, tgt_length, src_length).permute(1, 0, 2)  # Tt, B, Ts
    controls, gamma, read_labels, write_labels = self.hmm(observations, x)
    return x, observations, controls, gamma, read_labels, write_labels
def get_normalized_probs(self, net_output, log_probs):
    """Get normalized probabilities (or log probs) from a net's output."""
    logits = net_output[0]
    if log_probs:
        res = utils.log_softmax(logits.float(), dim=-1)
    else:
        res = utils.softmax(logits.float(), dim=-1)
    res.batch_first = True
    return res
def compute_cross_entropy(self, logits, sample, reduce=True):
    lprobs = utils.log_softmax(logits, dim=-1, onnx_trace=False)
    lprobs = lprobs.view(-1, lprobs.size(-1))
    target = sample['target'].view(-1)
    loss = F.nll_loss(
        lprobs,
        target,
        ignore_index=self.padding_idx,
        reduction='sum' if reduce else 'none',
    )
    return loss
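# Hedged equivalence sketch (illustrative only): compute_cross_entropy is
# log_softmax followed by F.nll_loss, which is equivalent (up to the float32
# cast inside utils.log_softmax) to torch's fused cross-entropy on the raw
# logits. Argument names below are hypothetical.
import torch.nn.functional as F

def cross_entropy_equivalent_example(logits, target, padding_idx, reduce=True):
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        target.view(-1),
        ignore_index=padding_idx,
        reduction='sum' if reduce else 'none',
    )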
def gumbel_softmax(self, logits, gs_tau=0.5, gs_hard=False, dim=-1):
    logprob = utils.log_softmax(logits, dim=dim, onnx_trace=self.onnx_trace)
    logprob = torch.clamp(logprob, math.log(0.1), math.log(0.9))
    gs = F.gumbel_softmax(
        logprob,
        tau=gs_tau,
        hard=gs_hard,
    )
    return gs
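# Hedged usage sketch (illustration only): drawing relaxed and straight-through
# samples with the gumbel_softmax helper above. `decoder` is a hypothetical
# module exposing that method.
def sample_orders_example(decoder, order_logits, tau=0.5):
    # relaxed sample: differentiable, sums to 1 along the last dimension
    soft = decoder.gumbel_softmax(order_logits, gs_tau=tau, gs_hard=False)
    # straight-through sample: one-hot forward, soft gradients backward
    hard = decoder.gumbel_softmax(order_logits, gs_tau=tau, gs_hard=True)
    return soft, hard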
def forward(self, sample, encoder_out, decoder_out):
    # First encode the observations
    if not self.share_embeddings:
        x = self.observation_grid(sample['src_tokens'], sample['prev_output_tokens'])
    else:
        # The writing input grid
        x = decoder_out[1].clone()
    # Cumulative ResNet:
    x = self.net(x)
    # Cell aggregation
    # The R/W decisions:
    x = self.gate_dropout(x)
    x = self.gate(x)
    s = F.logsigmoid(x)
    RWlogits = torch.cat((s, s - x), dim=-1).contiguous().float()
    with torch.no_grad():
        lprobs = decoder_out[0].clone()
        target = sample['target']
        encoder_mask = encoder_out['encoder_padding_mask']
        decoder_mask = decoder_out[2]
        # Gather the ground truth likelihoods
        B, Tt, Ts, V = lprobs.size()
        lprobs = utils.log_softmax(lprobs, dim=-1)
        scores = lprobs.view(-1, V).gather(
            dim=-1,
            index=target.unsqueeze(-1).expand(-1, -1, Ts).contiguous().view(-1, 1)  # BTtTs
        ).view(B, Tt, Ts)
        # Forbid padding positions:
        # I'm using NLL beware
        if encoder_mask is not None:
            scores = scores.masked_fill(encoder_mask.unsqueeze(1), -1000)
        if decoder_mask is not None:
            scores = scores.masked_fill(decoder_mask.unsqueeze(-1), -1000)
        # The Oracle
        best_context = self.oracle(scores)
        AP = best_context.add(1).float().mean(dim=1) / Ts
        print('-', round(AP.mean().data.item(), 2))
        Gamma = torch.zeros_like(scores).scatter_(
            -1, best_context.unsqueeze(-1), 1.0)  # B, Tt, Ts
        # Write beyond the ideal context
        if self.write_right:
            Gamma = Gamma.cumsum(dim=-1)
            write = Gamma[:, 1:]  # B, Tt-1, Ts
        else:
            write = Gamma[:, 1:].cumsum(dim=-1)  # B, Tt-1, Ts
        read = 1 - write
    return Gamma, RWlogits[:, :-1], read, write
def get_logits(self, net_output): logits = net_output["x"] logits = logits.transpose(0, 2) logits = logits.reshape(-1, logits.size(-1)) logtis_ctc = net_output["logtis_ctc"] lprob = utils.log_softmax(logtis_ctc.float(), dim=-1) lprob = lprob.transpose(0, 1) lprob.batch_first = False return logits, lprob
def compute_xet_loss(self, logits, gold, padding_idx, reduce=True):
    lprobs = utils.log_softmax(logits, dim=-1)
    lprobs = lprobs.view(-1, lprobs.size(-1))
    target = gold.contiguous().view(-1)
    loss = F.nll_loss(
        lprobs,
        target,
        ignore_index=padding_idx,
        reduction='sum' if reduce else 'none',
    )
    return loss, loss
def get_ctc_output(
    self,
    net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
    sample: Optional[Dict[str, Tensor]],
):
    encoder_out = net_output[1]["encoder_out"]["encoder_out"][0]
    logits = self.encoder.ctc_proj(encoder_out)  # T x B x C
    out = utils.log_softmax(logits.float(), dim=-1)
    padding_mask = net_output[1]["encoder_out"]["encoder_padding_mask"]
    lens = out.new_full((out.shape[1],), out.shape[0]).long()
    if len(padding_mask) > 0:
        lens -= padding_mask[0].sum(dim=-1)
    return out, lens
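# Hedged sketch (assumption, not from the original file): feeding the (out, lens)
# pair returned by get_ctc_output into torch's CTC loss. `targets` and
# `target_lengths` are hypothetical label tensors; blank index 0 is assumed.
import torch.nn.functional as F

def ctc_loss_example(out, lens, targets, target_lengths, blank_idx=0):
    # out: T x B x C log-probabilities; lens: valid input length per batch element
    return F.ctc_loss(
        out, targets, lens, target_lengths,
        blank=blank_idx, reduction="sum", zero_infinity=True,
    )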
def compute_loss(self, net_output, sample, reduce=True):
    lprobs = utils.log_softmax(net_output, dim=-1)
    lprobs = lprobs.view(-1, lprobs.size(-1))
    target = sample['target'].view(-1, 1)
    non_pad_mask = target.ne(self.padding_idx)
    nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    eps_i = self.eps / lprobs.size(-1)
    loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss
def forward(self, sample, encoder_out, decoder_out):
    x = decoder_out[1]
    # Final LN
    if self.final_ln is not None:
        x = self.final_ln(x)
    # Aggregate
    x, _ = self.aggregator(x)
    # A stack of linear layers
    x = self.net(x)
    # The R/W decisions:
    x = self.gate(x)
    s = F.logsigmoid(x)
    RWlogits = torch.cat((s, s - x), dim=-1).float()
    lprobs = decoder_out[0]
    target = sample['target']
    encoder_mask = encoder_out['encoder_padding_mask']
    decoder_mask = decoder_out[2]
    with torch.no_grad():
        # Gather the ground truth likelihoods
        B, Tt, Ts, V = lprobs.size()
        lprobs = utils.log_softmax(lprobs, dim=-1)
        scores = lprobs.view(-1, V).gather(
            dim=-1,
            index=target.unsqueeze(-1).expand(-1, -1, Ts).contiguous().view(-1, 1)  # BTtTs
        ).view(B, Tt, Ts)
        # Forbid padding positions:
        # I'm using NLL beware
        if encoder_mask is not None:
            scores = scores.masked_fill(encoder_mask.unsqueeze(1), -1000)
        if decoder_mask is not None:
            scores = scores.masked_fill(decoder_mask.unsqueeze(-1), -1000)
        # The Oracle
        best_context = self.oracle(scores)
        # AP = best_context.add(1).float().mean(dim=1) / Ts
        # print('AP:', ' '.join(map(lambda x: '{:.2f}'.format(x), AP.tolist())))
        Gamma = torch.zeros_like(scores).scatter_(
            -1, best_context.unsqueeze(-1), 1.0)  # B, Tt, Ts
        # Write beyond the ideal context
        if self.write_right:
            Gamma = Gamma.cumsum(dim=-1)
            write = Gamma[:, 1:]  # B, Tt-1, Ts
        else:
            write = Gamma[:, 1:].cumsum(dim=-1)  # B, Tt-1, Ts
        read = 1 - write
    return Gamma, RWlogits[:, :-1], read, write
def get_normalized_probs(self, net_output, log_probs, sample): """Get normalized probabilities (or log probs) from a net's output.""" # print('enter normalized.') if 'net_input' in sample.keys(): enc_seq_ids = sample['net_input']['src_tokens'] else: enc_seq_ids = sample['src_tokens'] # wvocab_size = net_output[0].size(2) # batch_size = enc_seq_ids.size(0) # seq_len = enc_seq_ids.size(1) # one_hot = torch.zeros(batch_size, seq_len, wvocab_size).cuda().scatter_(dim=2, index=enc_seq_ids.unsqueeze(-1), value=1) # # copy_probs = torch.matmul(net_output[1]['attn'], one_hot) # final_dist = vocab_dist.scatter_add(1, encoder_batch_extend_vocab, attn_dist) if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None: if sample is not None: assert 'target' in sample target = sample['target'] else: target = None out = self.adaptive_softmax.get_log_prob(net_output[0], target=target) return out.exp_() if not log_probs else out logits = net_output[0] if log_probs: generate = utils.softmax( logits, dim=-1, onnx_trace=self.onnx_trace) * net_output[1]['copy_or_generate'] copy = net_output[1]['attn'] * (1 - net_output[1]['copy_or_generate']) enc_seq_ids = enc_seq_ids.unsqueeze(1).repeat( 1, net_output[1]['attn'].size(1), 1) final = generate.scatter_add(2, enc_seq_ids, copy) final = torch.log(final + 1e-15) return final else: generate = utils.log_softmax( logits, dim=-1, onnx_trace=self.onnx_trace) * net_output[1]['copy_or_generate'] copy = net_output[1]['attn'] * (1 - net_output[1]['copy_or_generate']) enc_seq_ids = enc_seq_ids.unsqueeze(1).repeat( 1, net_output[1]['attn'].size(1), 1) final = generate.scatter_add(2, enc_seq_ids, copy) return final
def sesim_loss(lprobs, target, epsilon, task=None, bpe=None, rewarder=None, output_tokens=None,
               ignore_index=None, reduce=True, loss_weight=None, debug=True):
    if loss_weight is None:
        loss_weight = -15

    ## semantic sim_loss
    sentence_tok = torch.argmax(utils.log_softmax(output_tokens, dim=-1), -1)  # maxpool
    sentence_txt = bpe.decode(task.target_dictionary.string(sentence_tok))
    if ignore_index is not None:
        non_pad_mask = target.ne(ignore_index)
        target_ig = target[non_pad_mask]
        target_txt = bpe.decode(task.target_dictionary.string(target_ig))
    semsim_score = rewarder(target_txt, sentence_txt)
    if debug:
        print("\n\n## sentence_txt: ", sentence_txt, "\n## target_txt: ", target_txt,
              "\n## Reward :", semsim_score)

    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        non_pad_mask = target.ne(ignore_index)
        nll_loss = nll_loss[non_pad_mask]
        smooth_loss = smooth_loss[non_pad_mask]
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1. - epsilon) * nll_loss + eps_i * smooth_loss

    if debug:
        print("nll_loss, smooth_loss: ", nll_loss, smooth_loss)
        print("normal_loss, reward: ", loss, semsim_score)
        print('loss before:')
        print(loss)
    loss = loss * (1 - semsim_score)
    print('loss: ')
    print(loss)
    # loss = loss - loss_weight * semsim_score
    # LOG : loss
    # was 1:1, increased to 1: 100 | 20191212
    # original : loss + 100*semsim_score, neg : loss - 100*semsim_score | 20191212
    if debug:
        print("===" * 10)
    return loss, nll_loss, semsim_score  # semsim_score : semsim_score