def forward_att(self, eouts, elens, ys):
    """Compute XE loss for the sequence-to-sequence model.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (list): A list of length `[B]`
        ys (list): A list of length `[B]`, which contains a list of size `[L]`
    Returns:
        loss (FloatTensor): `[1]`
        acc (float):
        ppl (float):

    """
    bs = eouts.size(0)

    # Append <sos> and <eos>
    eos = eouts.new_zeros((1,)).fill_(self.eos).long()
    ylens = [len(y) for y in ys]
    ys = [np2tensor(np.fromiter(y[::-1] if self.backward else y, dtype=np.int64),
                    self.device_id).long() for y in ys]
    ys_in = [torch.cat([eos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]
    ys_in_pad = pad_list(ys_in, self.pad)
    ys_out_pad = pad_list(ys_out, self.pad)

    # Add positional embedding
    ys_emb = self.embed(ys_in_pad) * (self.d_model ** 0.5)
    if self.pe_type:
        ys_emb = self.pos_emb_out(ys_emb)
    for l in range(self.n_layers):
        ys_emb, yy_aw, xy_aw = self.layers[l](eouts, elens, ys_emb, ylens)
    logits = self.norm_top(ys_emb)
    if self.adaptive_softmax is None:
        logits = self.output(logits)

    # Compute XE sequence loss
    if self.adaptive_softmax is None:
        if self.lsm_prob > 0:
            # Label smoothing
            loss = cross_entropy_lsm(logits, ys_out_pad,
                                     ylens=[y.size(0) for y in ys_out],
                                     lsm_prob=self.lsm_prob, size_average=False) / bs
        else:
            loss = F.cross_entropy(logits.view((-1, logits.size(2))), ys_out_pad.view(-1),
                                   ignore_index=self.pad, size_average=False) / bs
    else:
        loss = self.adaptive_softmax(logits.view((-1, logits.size(2))),
                                     ys_out_pad.view(-1)).loss

    # Compute token-level accuracy in teacher-forcing
    if self.adaptive_softmax is None:
        acc = compute_accuracy(logits, ys_out_pad, pad=self.pad)
    else:
        acc = compute_accuracy(self.adaptive_softmax.log_prob(
            logits.view((-1, logits.size(2)))), ys_out_pad, pad=self.pad)
    ppl = min(np.exp(loss.item()), np.inf)

    return loss, acc, ppl

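
# The helpers used throughout these functions (np2tensor, pad_list, cross_entropy_lsm,
# compute_accuracy, ...) are defined elsewhere in the codebase. As a rough illustration
# only, pad_list can be assumed to behave like the following minimal, self-contained
# sketch (the name pad_list_sketch and its argument names are hypothetical):
import torch


def pad_list_sketch(xs, pad_value):
    """Pad a list of 1-D LongTensors to the length of the longest one.

    xs (list of LongTensor): each of shape `[L_i]`
    pad_value (int): value used for padding (e.g. self.pad)
    Returns a LongTensor of shape `[B, max(L_i)]`.
    """
    bs = len(xs)
    max_len = max(x.size(0) for x in xs)
    xs_pad = xs[0].new_full((bs, max_len), pad_value)
    for b, x in enumerate(xs):
        xs_pad[b, :x.size(0)] = x
    return xs_pad
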
def forward_att(self, eouts, elens, ys, trigger_points=None):
    """Compute XE loss for the Transformer decoder.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        ys (List): length `[B]`, each of which contains a list of size `[L]`
        trigger_points (IntTensor): `[B, L]`
    Returns:
        loss (FloatTensor): `[1]`
        acc (float): accuracy for token prediction
        ppl (float): perplexity
        losses_auxiliary (dict):

    """
    losses_auxiliary = {}

    # Append <sos> and <eos>
    ys_in, ys_out, ylens = append_sos_eos(ys, self.eos, self.eos, self.pad, self.device, self.bwd)
    if not self.training:
        self.data_dict['elens'] = tensor2np(elens)
        self.data_dict['ylens'] = tensor2np(ylens)
        self.data_dict['ys'] = tensor2np(ys_out)

    # Create target self-attention mask
    bs, ymax = ys_in.size()[:2]
    tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1])
    causal_mask = tgt_mask.new_ones(ymax, ymax, dtype=tgt_mask.dtype)
    causal_mask = torch.tril(causal_mask).unsqueeze(0)
    tgt_mask = tgt_mask & causal_mask  # `[B, L (query), L (key)]`

    # Create source-target mask
    src_mask = make_pad_mask(elens.to(self.device)).unsqueeze(1).repeat([1, ymax, 1])  # `[B, L, T]`

    # Create attention padding mask for quantity loss
    if self.attn_type == 'mocha':
        attn_mask = (ys_out != self.pad).unsqueeze(1).unsqueeze(3)  # `[B, 1, L, 1]`
    else:
        attn_mask = None

    # external LM integration
    lmout = None
    if self.lm is not None:
        self.lm.eval()
        with torch.no_grad():
            lmout, lmstate, _ = self.lm.predict(ys_in, None)
        lmout = self.lm_output_proj(lmout)

    out = self.pos_enc(self.embed_token_id(ys_in), scale=True)  # scaled + dropout

    xy_aws_layers = []
    xy_aws = None
    for lth, layer in enumerate(self.layers):
        out = layer(out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout)
        # Attention padding
        xy_aws = layer.xy_aws
        if xy_aws is not None and self.attn_type == 'mocha':
            xy_aws_masked = xy_aws.masked_fill_(attn_mask.expand_as(xy_aws) == 0, 0)
            # NOTE: attention padding is quite effective for quantity loss
            xy_aws_layers.append(xy_aws_masked.clone())
        if not self.training:
            self.aws_dict['yy_aws_layer%d' % lth] = tensor2np(layer.yy_aws)
            self.aws_dict['xy_aws_layer%d' % lth] = tensor2np(layer.xy_aws)
            self.aws_dict['xy_aws_beta_layer%d' % lth] = tensor2np(layer.xy_aws_beta)
            self.aws_dict['xy_aws_p_choose%d' % lth] = tensor2np(layer.xy_aws_p_choose)
            self.aws_dict['yy_aws_lm_layer%d' % lth] = tensor2np(layer.yy_aws_lm)

    logits = self.output(self.norm_out(out))

    # Compute XE loss (+ label smoothing)
    loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training)

    # Quantity loss
    losses_auxiliary['loss_quantity'] = 0.
    if self.attn_type == 'mocha':
        # Average over all heads across all layers
        n_tokens_ref = tgt_mask[:, -1, :].sum(1).float()  # `[B]`
        # NOTE: count <eos> tokens
        n_tokens_pred = sum([torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                             for aws in xy_aws_layers])  # `[B]`
        n_tokens_pred /= len(xy_aws_layers)
        losses_auxiliary['loss_quantity'] = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))

    # Compute token-level accuracy in teacher-forcing
    acc = compute_accuracy(logits, ys_out, self.pad)

    return loss, acc, ppl, losses_auxiliary

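
# append_sos_eos is likewise defined elsewhere, and its argument order differs between
# the variants in this file. Conceptually it builds teacher-forcing inputs and targets
# from raw token-id lists. A minimal sketch under that assumption (the signature below
# is illustrative only, not the project's actual API); here <eos> doubles as <sos>,
# matching the call sites above:
import torch
from torch.nn.utils.rnn import pad_sequence


def append_sos_eos_sketch(ys, sos, eos, pad, device, bwd=False):
    """ys: list of B token-id lists. Returns (ys_in `[B, L+1]`, ys_out `[B, L+1]`, ylens `[B]`)."""
    _sos = torch.tensor([sos], dtype=torch.long, device=device)
    _eos = torch.tensor([eos], dtype=torch.long, device=device)
    ys = [torch.tensor(y[::-1] if bwd else y, dtype=torch.long, device=device) for y in ys]
    ylens = torch.tensor([y.size(0) + 1 for y in ys], dtype=torch.int32)  # +1 for <eos>
    ys_in = pad_sequence([torch.cat([_sos, y], dim=0) for y in ys],
                         batch_first=True, padding_value=pad)
    ys_out = pad_sequence([torch.cat([y, _eos], dim=0) for y in ys],
                          batch_first=True, padding_value=pad)
    return ys_in, ys_out, ylens
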
def forward_att(self, eouts, elens, ys, return_logits=False,
                teacher_logits=None, trigger_points=None):
    """Compute XE loss for the Transformer decoder.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        ys (list): length `B`, each of which contains a list of size `[L]`
        return_logits (bool): return logits for knowledge distillation
        teacher_logits (FloatTensor): `[B, L, vocab]`
        trigger_points (IntTensor): `[B, T]`
    Returns:
        loss (FloatTensor): `[1]`
        acc (float): accuracy for token prediction
        ppl (float): perplexity
        losses_auxiliary (dict): auxiliary losses (e.g. the quantity loss)

    """
    # Append <sos> and <eos>
    ys_in, ys_out, ylens = append_sos_eos(eouts, ys, self.eos, self.eos, self.pad, self.bwd)
    if not self.training:
        self.data_dict['elens'] = tensor2np(elens)
        self.data_dict['ylens'] = tensor2np(ylens)
        self.data_dict['ys'] = tensor2np(ys_out)

    # Create target self-attention mask
    xmax = eouts.size(1)
    bs, ymax = ys_in.size()[:2]
    mlen = 0
    tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1])
    causal_mask = tgt_mask.new_ones(ymax, ymax).byte()
    causal_mask = torch.tril(causal_mask, diagonal=0 + mlen, out=causal_mask).unsqueeze(0)
    tgt_mask = tgt_mask & causal_mask  # `[B, L (query), L (key)]`

    # Create source-target mask
    src_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat([1, ymax, 1])  # `[B, L, T]`

    # external LM integration
    lmout = None
    if self.lm is not None:
        self.lm.eval()
        with torch.no_grad():
            lmout, lmstate, _ = self.lm.predict(ys_in, None)
        lmout = self.lm_output_proj(lmout)

    out = self.pos_enc(self.embed(ys_in))  # scaled

    mems = self.init_memory()
    pos_embs = None
    if self.memory_transformer:
        out = self.dropout_emb(out)
        # NOTE: TransformerXL does not use positional encoding in the token embedding
        # adopt zero-centered offset
        pos_idxs = torch.arange(mlen - 1, -ymax - 1, -1.0, dtype=torch.float)
        pos_embs = self.pos_emb(pos_idxs, self.device_id)

    hidden_states = [out]
    xy_aws_layers = []
    for lth, (mem, layer) in enumerate(zip(mems, self.layers)):
        out = layer(out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout,
                    pos_embs=pos_embs, memory=mem, u=self.u, v=self.v)
        if lth < self.n_layers - 1:
            hidden_states.append(out)
            # NOTE: outputs from the last layer are not used for the memory
        # Attention padding
        xy_aws = layer.xy_aws
        if xy_aws is not None and 'mocha' in self.attn_type:
            tgt_mask_v2 = (ys_out != self.pad).unsqueeze(1).unsqueeze(3)  # `[B, 1, L, 1]`
            xy_aws = xy_aws.masked_fill_(tgt_mask_v2.repeat([1, xy_aws.size(1), 1, xmax]) == 0, 0)
            # NOTE: attention padding is quite effective for quantity loss
            xy_aws_layers.append(xy_aws.clone())
        if not self.training:
            if layer.yy_aws is not None:
                self.aws_dict['yy_aws_layer%d' % lth] = tensor2np(layer.yy_aws)
            if layer.xy_aws is not None:
                self.aws_dict['xy_aws_layer%d' % lth] = tensor2np(layer.xy_aws)
            if layer.xy_aws_beta is not None:
                self.aws_dict['xy_aws_beta_layer%d' % lth] = tensor2np(layer.xy_aws_beta)
            if layer.xy_aws_p_choose is not None:
                self.aws_dict['xy_aws_p_choose%d' % lth] = tensor2np(layer.xy_aws_p_choose)
            if layer.yy_aws_lm is not None:
                self.aws_dict['yy_aws_lm_layer%d' % lth] = tensor2np(layer.yy_aws_lm)

    logits = self.output(self.norm_out(out))

    # for knowledge distillation
    if return_logits:
        return logits

    # Compute XE loss (+ label smoothing)
    loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training)

    losses_auxiliary = {}

    # Quantity loss
    losses_auxiliary['loss_quantity'] = 0.
    if 'mocha' in self.attn_type:
        # Average over all heads across all layers
        n_tokens_ref = tgt_mask[:, -1, :].sum(1).float()  # `[B]`
        # NOTE: count <eos> tokens
        n_tokens_pred = sum([torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                             for aws in xy_aws_layers])  # `[B]`
        n_tokens_pred /= len(xy_aws_layers)
        losses_auxiliary['loss_quantity'] = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))

    # Compute token-level accuracy in teacher-forcing
    acc = compute_accuracy(logits, ys_out, self.pad)

    return loss, acc, ppl, losses_auxiliary

def forward_att(self, eouts, elens, ys):
    """Compute XE loss for the sequence-to-sequence model.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (list): A list of length `[B]`
        ys (list): A list of length `[B]`, which contains a list of size `[L]`
    Returns:
        loss (FloatTensor): `[1]`
        acc (float):
        ppl (float):

    """
    # Append <sos> and <eos>
    eos = eouts.new_zeros((1,)).fill_(self.eos).long()
    ys = [np2tensor(np.fromiter(y[::-1] if self.backward else y, dtype=np.int64),
                    self.device_id).long() for y in ys]
    ys_in = [torch.cat([eos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]
    ys_in_pad = pad_list(ys_in, self.pad)
    ys_out_pad = pad_list(ys_out, -1)

    # Add positional embedding
    ys_emb = self.embed(ys_in_pad) * (self.d_model ** 0.5)
    if self.pe_type:
        ys_emb = self.pos_emb_out(ys_emb)

    # Make source-target attention mask: `[B, L(query), T(key)]`
    bs, max_xlen = eouts.size()[:2]
    y_len_max = ys_in_pad.size(1)
    yx_mask = (ys_in_pad != self.pad).unsqueeze(-1).expand(bs, y_len_max, max_xlen)
    for b in range(bs):
        if elens[b] < max_xlen:
            yx_mask[b, :, elens[b]:] = 0

    # Make target-side self-attention mask (hide future tokens): `[B, L(query), L(key)]`
    yy_mask = (ys_in_pad != self.pad).unsqueeze(-2).expand(bs, y_len_max, y_len_max)
    history_mask = torch.triu(torch.ones((y_len_max, y_len_max),
                                         device=self.device_id, dtype=torch.uint8), diagonal=1)
    history_mask = history_mask.unsqueeze(0).expand(bs, -1, -1) == 0
    yy_mask = yy_mask & history_mask

    for l in range(self.n_layers):
        ys_emb, yy_aw, xy_aw = self.layers[l](eouts, ys_emb, yx_mask, yy_mask)
    ys_emb = self.layer_norm_top(ys_emb)
    logits = self.output(ys_emb)

    # Compute XE sequence loss
    if self.lsm_prob > 0:
        # Label smoothing
        loss = cross_entropy_lsm(logits, ys=ys_out_pad,
                                 ylens=[y.size(0) for y in ys_out],
                                 lsm_prob=self.lsm_prob, size_average=True)
    else:
        loss = F.cross_entropy(input=logits.view((-1, logits.size(2))),
                               target=ys_out_pad.view(-1),  # long
                               ignore_index=-1, size_average=False) / bs

    # Compute token-level accuracy in teacher-forcing
    pad_pred = logits.view(ys_out_pad.size(0), ys_out_pad.size(1), logits.size(-1)).argmax(2)
    mask = ys_out_pad != -1
    numerator = (pad_pred.masked_select(mask) == ys_out_pad.masked_select(mask)).sum()
    denominator = mask.sum()
    acc = float(numerator) * 100 / float(denominator)
    ppl = np.exp(loss.item())

    return loss, acc, ppl

def _forward(self, ys, state, reporter, n_caches=0, predict_last=False):
    ys = [np2tensor(y, self.device_id) for y in ys]  # <eos> is included
    ylens = np2tensor(np.fromiter([y.size(0) - 1 for y in ys], dtype=np.int32))  # -1 for <eos>
    ys = pad_list(ys, self.pad)
    ys_in = ys[:, :-1]
    ys_out = ys[:, 1:]

    out, state = self.decode(ys_in, state)
    if self.adaptive_softmax is None:
        logits = self.output(out)
    else:
        logits = out

    if predict_last:
        ys_out = ys_out[:, -1].unsqueeze(1)
        logits = logits[:, -1].unsqueeze(1)

    # Compute XE sequence loss
    if n_caches > 0 and len(self.cache_ids) > 0:
        assert ys_out.size(1) == 1
        assert ys_out.size(0) == 1
        if self.adaptive_softmax is None:
            probs = F.softmax(logits, dim=-1)
        else:
            probs = self.adaptive_softmax.log_prob(logits).exp()
        cache_probs = probs.new_zeros(probs.size())

        # Truncate cache
        self.cache_ids = self.cache_ids[-n_caches:]  # list of `[B, 1]`
        self.cache_keys = self.cache_keys[-n_caches:]  # list of `[B, 1, n_units]`

        # Compute inner-product over caches
        cache_attn = F.softmax(self.cache_theta * torch.matmul(
            torch.cat(self.cache_keys, dim=1), out.transpose(2, 1)).squeeze(2), dim=1)

        # For visualization
        if len(self.cache_ids) == n_caches:
            self.cache_attn += [cache_attn.cpu().numpy()]
            self.cache_attn = self.cache_attn[-n_caches:]

        # Sum all probabilities
        for offset, idx in enumerate(self.cache_ids):
            cache_probs[:, :, idx] += cache_attn[:, offset]
        probs = (1 - self.cache_lambda) * probs + self.cache_lambda * cache_probs
        loss = -torch.log(probs[:, :, ys_out[:, -1]])
    else:
        if self.adaptive_softmax is None:
            if self.lsm_prob > 0 and self.training:
                # Label smoothing
                loss = cross_entropy_lsm(logits.view((-1, logits.size(2))),
                                         ys_out.contiguous().view(-1),
                                         self.lsm_prob, self.pad)
            else:
                loss = F.cross_entropy(logits.view((-1, logits.size(2))),
                                       ys_out.contiguous().view(-1),
                                       ignore_index=self.pad, size_average=True)
        else:
            loss = self.adaptive_softmax(logits.view((-1, logits.size(2))),
                                         ys_out.contiguous().view(-1)).loss

    if n_caches > 0:
        # Register to cache
        self.cache_ids += [ys_out[0, -1].item()]
        self.cache_keys += [out]

    # Compute token-level accuracy in teacher-forcing
    if self.adaptive_softmax is None:
        acc = compute_accuracy(logits, ys_out, pad=self.pad)
    else:
        acc = compute_accuracy(self.adaptive_softmax.log_prob(
            logits.view((-1, logits.size(2)))), ys_out, pad=self.pad)

    observation = {'loss.lm': loss.item(),
                   'acc.lm': acc,
                   'ppl.lm': min(np.exp(loss.item()), np.inf)}

    # Report here
    if reporter is not None:
        is_eval = not self.training
        reporter.add(observation, is_eval)

    return loss, state, reporter

def forward_att(self, eouts, elens, ys, return_logits=False):
    """Compute XE loss for the sequence-to-sequence model.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        ys (list): A list of length `[B]`, which contains a list of size `[L]`
        return_logits (bool): return logits for knowledge distillation
    Returns:
        loss (FloatTensor): `[1]`
        acc (float):
        ppl (float):

    """
    bs = eouts.size(0)

    # Append <sos> and <eos>
    eos = eouts.new_zeros(1).fill_(self.eos).long()
    ys = [np2tensor(np.fromiter(y[::-1] if self.bwd else y, dtype=np.int64),
                    self.device_id) for y in ys]
    ylens = np2tensor(np.fromiter([y.size(0) + 1 for y in ys], dtype=np.int32))  # +1 for <eos>
    ys_in_pad = pad_list([torch.cat([eos, y], dim=0) for y in ys], self.pad)
    ys_out_pad = pad_list([torch.cat([y, eos], dim=0) for y in ys], self.pad)

    # Create the self-attention mask
    bs, ymax = ys_in_pad.size()[:2]
    yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).expand(bs, ymax, ymax)
    yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, ymax, ymax)
    subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(), diagonal=0)
    subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(
        bs, self.attn_n_heads, ymax, ymax)
    yy_mask = yy_mask & subsequent_mask

    # Create the source-target mask
    xmax = eouts.size(1)
    x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).expand(bs, ymax, xmax)
    y_mask = make_pad_mask(ylens, self.device_id).unsqueeze(2).expand(bs, ymax, xmax)
    xy_mask = (x_mask * y_mask).unsqueeze(1).expand(bs, self.attn_n_heads, ymax, xmax)

    ys_emb = self.pos_enc(self.embed(ys_in_pad))
    for l in range(self.n_layers):
        ys_emb, yy_aws, xy_aws = self.layers[l](ys_emb, yy_mask, eouts, xy_mask)
        if not self.training:
            setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
            setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws))
    logits = self.norm_out(ys_emb)
    if self.adaptive_softmax is None:
        logits = self.output(logits)

    # for knowledge distillation
    if return_logits:
        return logits

    # Compute XE sequence loss
    if self.adaptive_softmax is None:
        if self.lsm_prob > 0 and self.training:
            # Label smoothing
            loss = cross_entropy_lsm(logits.view((-1, logits.size(2))),
                                     ys_out_pad.view(-1), self.lsm_prob, self.pad)
        else:
            loss = F.cross_entropy(logits.view((-1, logits.size(2))), ys_out_pad.view(-1),
                                   ignore_index=self.pad, size_average=True)

        # Focal loss
        if self.focal_loss_weight > 0:
            fl = focal_loss(logits, ys_out_pad, ylens,
                            alpha=self.focal_loss_weight,
                            gamma=self.focal_loss_gamma)
            loss = loss * (1 - self.focal_loss_weight) + fl * self.focal_loss_weight
    else:
        loss = self.adaptive_softmax(logits.view((-1, logits.size(2))),
                                     ys_out_pad.view(-1)).loss

    # Compute token-level accuracy in teacher-forcing
    if self.adaptive_softmax is None:
        acc = compute_accuracy(logits, ys_out_pad, self.pad)
    else:
        acc = compute_accuracy(self.adaptive_softmax.log_prob(
            logits.view((-1, logits.size(2)))), ys_out_pad, self.pad)
    ppl = min(np.exp(loss.item()), np.inf)

    # scale loss for CTC
    loss *= ylens.float().mean()

    return loss, acc, ppl

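
# make_pad_mask is assumed here to return a `[B, Lmax]` boolean mask that is True at
# valid (non-padded) positions, which is what the mask arithmetic above relies on.
# A minimal sketch of that assumed behaviour (name and device argument are illustrative):
import torch


def make_pad_mask_sketch(seq_lens, device=None):
    """seq_lens (IntTensor): `[B]`. Returns BoolTensor `[B, max(seq_lens)]`, True = valid."""
    bs = seq_lens.size(0)
    max_len = int(seq_lens.max().item())
    arange = torch.arange(max_len, device=device).unsqueeze(0).expand(bs, max_len)
    return arange < seq_lens.unsqueeze(1).to(device)
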
def forward_att(self, eouts, elens, ys, device_id):
    """Compute XE loss for the sequence-to-sequence model.

    Args:
        eouts (FloatTensor): `[B, T, dec_units]`
        elens (list): A list of length `[B]`
        ys (list): A list of length `[B]`, which contains a list of size `[L]`
        device_id (int):
    Returns:
        loss (FloatTensor): `[1]`
        acc (float):
        ppl (float):

    """
    bs, _, enc_nunits = eouts.size()

    # Append <sos> and <eos>
    sos = eouts.new_zeros(1).fill_(self.sos).long()
    eos = eouts.new_zeros(1).fill_(self.eos).long()
    if self.backward:
        ys = [np2tensor(np.fromiter(y[::-1], dtype=np.int64), device_id).long() for y in ys]
        ys_in = [torch.cat([eos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, sos], dim=0) for y in ys]
    else:
        ys = [np2tensor(np.fromiter(y, dtype=np.int64), device_id).long() for y in ys]
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]
    ys_in_pad = pad_list(ys_in, self.pad)
    ys_out_pad = pad_list(ys_out, -1)

    # Initialization
    dout, dstate = self.init_dec_state(bs, self.nlayers, device_id, eouts, elens)
    _dout, _dstate = self.init_dec_state(bs, 1, device_id, eouts, elens)  # for internal LM
    context = eouts.new_zeros(bs, 1, enc_nunits)
    self.score.reset()
    aw = None
    rnnlm_state = None

    # Pre-computation of embedding
    ys_emb = self.embed(ys_in_pad)
    if self.rnnlm_cf:
        ys_lm_emb = self.rnnlm_cf.embed(ys_in_pad)
        # ys_lm_emb = [self.rnnlm_cf.embed(ys_in_pad[:, t:t + 1])
        #              for t in range(ys_in_pad.size(1))]
        # ys_lm_emb = torch.cat(ys_lm_emb, dim=1)

    logits = []
    for t in range(ys_in_pad.size(1)):
        # Sample for scheduled sampling
        is_sample = t > 0 and self.ss_prob > 0 and random.random() < self.ss_prob
        if is_sample:
            y_emb = self.embed(torch.argmax(logits[-1].detach(), dim=-1))
        else:
            y_emb = ys_emb[:, t:t + 1]

        # Recurrency
        dout, dstate, _dout, _dstate = self.recurrency(y_emb, context, dstate, _dstate)

        # Update RNNLM states for cold fusion
        if self.rnnlm_cf:
            if is_sample:
                y_lm_emb = self.rnnlm_cf.embed(
                    np.argmax(logits[-1].detach(), axis=2).cuda(device_id))
            else:
                y_lm_emb = ys_lm_emb[:, t:t + 1]
            logits_lm_t, lm_out, rnnlm_state = self.rnnlm_cf.predict(y_lm_emb, rnnlm_state)
        else:
            logits_lm_t, lm_out = None, None

        # Score
        context, aw = self.score(eouts, elens, dout, aw)

        # Generate
        attentional_t = self.generate(context, dout, logits_lm_t, lm_out)
        if self.rnnlm_init and self.internal_lm:
            # Residual connection
            attentional_t += _dout
        logits.append(self.output(attentional_t))

    # Compute XE sequence loss
    logits = torch.cat(logits, dim=1) / self.logits_temp
    if self.lsm_prob > 0:
        # Label smoothing
        y_lens = [y.size(0) for y in ys_out]
        loss = cross_entropy_lsm(logits, ys=ys_out_pad, y_lens=y_lens,
                                 lsm_prob=self.lsm_prob, size_average=True)
    else:
        loss = F.cross_entropy(logits.view((-1, logits.size(2))),
                               ys_out_pad.view(-1),  # long
                               ignore_index=-1, size_average=False) / bs
    ppl = math.exp(loss.item())

    # Focal loss
    if self.fl_weight > 0:
        y_lens = [y.size(0) for y in ys_out]
        fl = focal_loss(logits, ys=ys_out_pad, y_lens=y_lens,
                        gamma=self.fl_gamma, size_average=True)
        loss = loss * (1 - self.fl_weight) + fl * self.fl_weight

    # Compute token-level accuracy in teacher-forcing
    pad_pred = logits.view(ys_out_pad.size(0), ys_out_pad.size(1), logits.size(-1)).argmax(2)
    mask = ys_out_pad != -1
    numerator = torch.sum(pad_pred.masked_select(mask) == ys_out_pad.masked_select(mask))
    denominator = torch.sum(mask)
    acc = float(numerator) * 100 / float(denominator)

    return loss, acc, ppl

def forward_att(self, eouts, elens, ys):
    """Compute XE loss for the CIF model.

    Args:
        eouts (FloatTensor): `[B, T, dec_n_units]`
        elens (IntTensor): `[B]`
        ys (list): A list of length `[B]`, which contains a list of size `[L]`
    Returns:
        loss (FloatTensor): `[1]`
        acc (float): accuracy for token prediction
        ppl (float): perplexity

    """
    bs, xmax = eouts.size()[:2]

    # Append <sos> and <eos>
    ys_in_pad, ys_out_pad, ylens = self.append_sos_eos(ys, self.bwd)

    # Initialization
    dstate = self.zero_state(bs)
    lmouts, lmstate = None, None

    # CIF
    cvs, alpha, aws = self.score(eouts, elens, ylens)

    # Update LM states for LM fusion
    if self.lm is not None:
        lmouts, _ = self.lm.decode(ys_in_pad, lmstate)

    # Recurrency -> Score -> Generate
    ys_emb = self.embed(ys_in_pad)
    dstate, logits = self.decode_step(cvs, dstate, ys_emb, lmouts)
    logits = self.output(logits)

    # for attention plot
    if not self.training:
        self.aws = tensor2np(aws)  # `[B, n_heads, L, T]`

    # Compute XE sequence loss
    if self.lsm_prob > 0 and self.training:
        # Label smoothing
        loss = cross_entropy_lsm(logits.view((-1, logits.size(2))),
                                 ys_out_pad.view(-1), self.lsm_prob, self.pad)
    else:
        loss = F.cross_entropy(logits.view((-1, logits.size(2))), ys_out_pad.view(-1),
                               ignore_index=self.pad, size_average=True)

    # Quantity loss for CIF
    loss += torch.mean(torch.abs(alpha.sum(1) - ylens.float().cuda(self.device_id))) \
        * self.quantity_loss_weight

    # Compute token-level accuracy in teacher-forcing
    acc = compute_accuracy(logits, ys_out_pad, self.pad)
    ppl = np.exp(loss.item())

    # scale loss for CTC
    loss *= ylens.float().mean()

    return loss, acc, ppl

def _forward(self, ys, state, n_caches=0, predict_last=False):
    ys = [np2tensor(y, self.device) for y in ys]  # <eos> is included
    ys = pad_list(ys, self.pad)
    ys_in, ys_out = ys[:, :-1], ys[:, 1:]

    logits, out, new_state = self.decode(ys_in, state=state, mems=state)
    # NOTE: state=state is used for RNNLM while mems=state is used for TransformerXL.
    # TransformerLM ignores both of them.

    if predict_last:
        ys_out = ys_out[:, -1].unsqueeze(1)
        logits = logits[:, -1].unsqueeze(1)

    # Compute XE sequence loss
    if n_caches > 0 and len(self.cache_ids) > 0:
        assert ys_out.size(1) == 1
        assert ys_out.size(0) == 1
        if self.adaptive_softmax is None:
            probs = torch.softmax(logits, dim=-1)
        else:
            probs = self.adaptive_softmax.log_prob(logits).exp()
        cache_probs = probs.new_zeros(probs.size())

        # Truncate cache
        self.cache_ids = self.cache_ids[-n_caches:]  # list of `[B, 1]`
        self.cache_keys = self.cache_keys[-n_caches:]  # list of `[B, 1, n_units]`

        # Compute inner-product over caches
        cache_attn = torch.softmax(self.cache_theta * torch.matmul(
            torch.cat(self.cache_keys, dim=1), out.transpose(2, 1)).squeeze(2), dim=1)

        # For visualization
        if len(self.cache_ids) == n_caches:
            self.cache_attn += [cache_attn.cpu().numpy()]
            self.cache_attn = self.cache_attn[-n_caches:]

        # Sum all probabilities
        for offset, idx in enumerate(self.cache_ids):
            cache_probs[:, :, idx] += cache_attn[:, offset]
        probs = (1 - self.cache_lambda) * probs + self.cache_lambda * cache_probs
        loss = -torch.log(probs[:, :, ys_out[:, -1]])
        ppl = np.exp(loss.item())
        # NOTE: ppl must be defined in this branch as well so that `observation` below is valid
    else:
        if self.adaptive_softmax is None:
            loss, ppl = cross_entropy_lsm(logits, ys_out.contiguous(),
                                          self.lsm_prob, self.pad, self.training,
                                          normalize_length=True)
        else:
            loss = self.adaptive_softmax(logits.view((-1, logits.size(2))),
                                         ys_out.contiguous().view(-1)).loss
            ppl = np.exp(loss.item())

    if n_caches > 0:
        # Register to cache
        self.cache_ids += [ys_out[0, -1].item()]
        self.cache_keys += [out]

    # Compute token-level accuracy in teacher-forcing
    if self.adaptive_softmax is None:
        acc = compute_accuracy(logits, ys_out, pad=self.pad)
    else:
        acc = compute_accuracy(self.adaptive_softmax.log_prob(
            logits.view((-1, logits.size(2)))), ys_out, pad=self.pad)

    observation = {'loss.lm': loss.item(), 'acc.lm': acc, 'ppl.lm': ppl}

    return loss, new_state, observation

def forward(self, enc_out, enc_lens, ys):
    """Decoding in the training stage. Compute XE loss.

    Args:
        enc_out (torch.autograd.Variable, float): A tensor of size
            `[B, T, enc_num_units]`
        enc_lens (list): A list of length `[B]`
        ys (list): A list of length `[B]`, which contains Variables of size `[L]`
    Returns:
        loss_acc (dict): a dictionary containing 'loss', 'loss_att', 'loss_ctc',
            'loss_lm', and 'acc'

    """
    # Compute the auxiliary CTC loss
    if self.ctc_weight > 0:
        logits_ctc = self.output_ctc(enc_out)
        loss_ctc = self._compute_ctc_loss(logits_ctc, enc_lens, ys)
        device_id = enc_out.get_device()
        if device_id >= 0:
            loss_ctc = loss_ctc.cuda(device_id)
        loss = loss_ctc * self.ctc_weight
    else:
        loss_ctc = 0
        loss = 0.

    if self.ctc_weight == 1:
        loss_acc = {'loss': loss,
                    'loss_att': 0,
                    'loss_ctc': loss_ctc,
                    'loss_lm': 0,
                    'acc': 0}
        return loss_acc

    # Reverse the order
    if self.backward:
        ys = [y[::-1] for y in ys]

    # Append <sos> and <eos>
    sos = Variable(enc_out.data.new(1,).fill_(self.sos).long())
    eos = Variable(enc_out.data.new(1,).fill_(self.eos).long())
    ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]
    ys_in_pad = pad_list(ys_in, self.pad)
    ys_out_pad = pad_list(ys_out, -1)

    # Initialization
    dec_out, dec_state = self._init_dec_state(enc_out, enc_lens, self.num_layers)
    _dec_out, _dec_state = self._init_dec_state(enc_out, enc_lens, 1)  # for internal LM
    self.score.reset()
    aw_t = None
    rnnlm_state = None

    # Pre-computation of embedding
    ys_emb = self.emb(ys_in_pad)
    if self.rnnlm_cf is not None:
        ys_lm_emb = [self.rnnlm_cf.emb(ys_in_pad[:, t:t + 1])
                     for t in six.moves.range(ys_in_pad.size(1))]
        ys_lm_emb = torch.cat(ys_lm_emb, dim=1)

    logits_att, logits_lm = [], []
    for t in six.moves.range(ys_in_pad.size(1)):
        is_sample = t > 0 and self.ss_prob > 0 and random.random() < self.ss_prob

        # Score
        cv, aw_t = self.score(enc_out, enc_lens, dec_out, aw_t)

        # Update RNNLM states for cold fusion
        if self.rnnlm_cf is not None:
            if is_sample:
                device_id = logits_att[-1].get_device()
                y_lm_emb = self.rnnlm_cf.emb(
                    np.argmax(logits_att[-1].detach(), axis=2).cuda(device_id))
            else:
                y_lm_emb = ys_lm_emb[:, t:t + 1]
            logits_lm_t, lm_out, rnnlm_state = self.rnnlm_cf.predict(y_lm_emb, rnnlm_state)
        else:
            logits_lm_t, lm_out = None, None

        # Generate
        logits_att_t = self._generate(cv, dec_out, logits_lm_t, lm_out)

        # Residual connection
        if self.rnnlm_init is not None and self.internal_lm:
            logits_att_t += _dec_out

        if self.share_softmax or self.rnnlm_init is not None:
            logits_att_t = self.output_bottle(logits_att_t)
        logits_att_t = self.output(logits_att_t)
        if self.rnnlm_cf is not None:
            logits_att_t = F.relu(logits_att_t)
        logits_att.append(logits_att_t)

        if t == ys_in_pad.size(1) - 1:
            break

        # Sample for scheduled sampling
        if is_sample:
            device_id = logits_att[-1].get_device()
            y_emb = self.emb(np.argmax(logits_att[-1].detach(), axis=2).cuda(device_id))
        else:
            y_emb = ys_emb[:, t + 1:t + 2]

        # Recurrency
        dec_out, dec_state, _dec_out, _dec_state = self._recurrency(
            y_emb, cv, dec_state, _dec_state)

        if self.rnnlm_weight > 0:
            if self.share_softmax:
                logits_lm_t = self.output(_dec_out)
            else:
                logits_lm_t = self.output_rnnlm(_dec_out)
            logits_lm.append(logits_lm_t)

    logits_att = torch.cat(logits_att, dim=1) / self.logits_temp

    # Compute XE sequence loss
    if self.lsm_prob > 0:
        # Label smoothing
        y_lens = [y.size(0) for y in ys_out]
        loss_att = cross_entropy_lsm(logits_att, ys=ys_out_pad, y_lens=y_lens,
                                     lsm_prob=self.lsm_prob, lsm_type=self.lsm_type,
                                     size_average=True)
    else:
        loss_att = F.cross_entropy(input=logits_att.view((-1, logits_att.size(2))),
                                   target=ys_out_pad.view(-1),  # long
                                   ignore_index=-1, size_average=False) / len(enc_out)
    loss += loss_att * (1 - self.ctc_weight)

    # Compute XE loss for RNNLM objective
    if self.rnnlm_weight > 0:
        logits_lm = torch.cat(logits_lm, dim=1)
        loss_lm = F.cross_entropy(input=logits_lm.view((-1, logits_lm.size(2))),
                                  target=ys_out_pad[:, 1:].contiguous().view(-1),
                                  ignore_index=-1, size_average=True)
        loss += loss_lm * self.rnnlm_weight
    else:
        loss_lm = 0

    # Compute token-level accuracy in teacher-forcing
    pad_pred = logits_att.data.view(ys_out_pad.size(0), ys_out_pad.size(1),
                                    logits_att.size(-1)).max(2)[1]
    mask = ys_out_pad.data != -1
    numerator = torch.sum(pad_pred.masked_select(mask) == ys_out_pad.data.masked_select(mask))
    denominator = torch.sum(mask)
    acc = float(numerator) / float(denominator)

    loss_acc = {'loss': loss,
                'loss_att': loss_att,
                'loss_ctc': loss_ctc,
                'loss_lm': loss_lm,
                'acc': acc}
    return loss_acc

def forward_att(self, eouts, elens, ys, ys_hist=[],
                return_logits=False, teacher_logits=None):
    """Compute XE loss for the Transformer model.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        ys (list): A list of length `[B]`, which contains a list of size `[L]`
        ys_hist (list):
        return_logits (bool): return logits for knowledge distillation
        teacher_logits (FloatTensor): `[B, L, vocab]`
    Returns:
        loss (FloatTensor): `[1]`
        acc (float): accuracy for token prediction
        ppl (float): perplexity

    """
    bs = eouts.size(0)

    # Append <sos> and <eos>
    ys_in, ys_out, ylens = append_sos_eos(eouts, ys, self.eos, self.pad, self.bwd)

    # Create the self-attention mask
    bs, ytime = ys_in.size()[:2]
    tgt_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).repeat([1, ytime, 1])
    subsequent_mask = tgt_mask.new_ones(ytime, ytime).byte()
    subsequent_mask = torch.tril(subsequent_mask, out=subsequent_mask).unsqueeze(0)
    tgt_mask = tgt_mask & subsequent_mask

    # Create the source-target mask
    src_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat([1, ytime, 1])

    out = self.pos_enc(self.embed(ys_in))
    for l in range(self.n_layers):
        out, yy_aws, xy_aws = self.layers[l](out, tgt_mask, eouts, src_mask)
        if not self.training:
            setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
            setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws))
    logits = self.output(self.norm_out(out))

    # for knowledge distillation
    if return_logits:
        return logits

    # Compute XE sequence loss (+ label smoothing)
    loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training)

    # Compute token-level accuracy in teacher-forcing
    acc = compute_accuracy(logits, ys_out, self.pad)

    return loss, acc, ppl

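
# cross_entropy_lsm combines label smoothing with padding-aware normalisation and also
# returns the perplexity, which is how the newer variants above call it. A minimal
# sketch consistent with those call sites (the project's actual implementation may
# differ in its normalisation details; the name and signature below are illustrative):
import torch
import torch.nn.functional as F


def cross_entropy_lsm_sketch(logits, ys, lsm_prob, ignore_index, training):
    """logits: `[B, L, vocab]`, ys: `[B, L]` padded with ignore_index. Returns (loss, ppl)."""
    bs, ymax, vocab = logits.size()
    logits = logits.view(-1, vocab)
    ys = ys.contiguous().view(-1)
    if lsm_prob == 0 or not training:
        loss = F.cross_entropy(logits, ys, ignore_index=ignore_index, reduction='mean')
    else:
        log_probs = F.log_softmax(logits, dim=-1)
        mask = ys != ignore_index
        # smoothed target distribution: 1 - lsm_prob on the reference token,
        # lsm_prob spread uniformly over the remaining classes
        target = torch.full_like(log_probs, lsm_prob / (vocab - 1))
        target.scatter_(1, ys.masked_fill(~mask, 0).unsqueeze(1), 1 - lsm_prob)
        loss = -(target * log_probs).sum(dim=-1)
        loss = loss.masked_select(mask).mean()
    ppl = float(torch.exp(loss).item())  # per-token perplexity
    return loss, ppl
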
def forward_att(self, eouts, elens, ys, return_logits=False,
                teacher_logits=None, trigger_points=None):
    """Compute XE loss for the Transformer decoder.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        ys (list): length `B`, each of which contains a list of size `[L]`
        return_logits (bool): return logits for knowledge distillation
        teacher_logits (FloatTensor): `[B, L, vocab]`
        trigger_points (IntTensor): `[B, T]`
    Returns:
        loss (FloatTensor): `[1]`
        acc (float): accuracy for token prediction
        ppl (float): perplexity
        loss_quantity (FloatTensor): `[1]`
        loss_headdiv (FloatTensor): `[1]`
        loss_latency (FloatTensor): `[1]`

    """
    # Append <sos> and <eos>
    ys_in, ys_out, ylens = append_sos_eos(eouts, ys, self.eos, self.eos, self.pad, self.bwd)
    if not self.training:
        self.data_dict['elens'] = tensor2np(elens)
        self.data_dict['ylens'] = tensor2np(ylens)
        self.data_dict['ys'] = tensor2np(ys_out)

    # Create target self-attention mask
    xtime = eouts.size(1)
    bs, ymax = ys_in.size()[:2]
    tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1])
    causal_mask = tgt_mask.new_ones(ymax, ymax).byte()
    causal_mask = torch.tril(causal_mask, out=causal_mask).unsqueeze(0)
    tgt_mask = tgt_mask & causal_mask  # `[B, L, L]`

    # Create source-target mask
    src_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat([1, ymax, 1])  # `[B, L, T]`

    # external LM integration
    lmout = None
    if self.lm is not None:
        self.lm.eval()
        with torch.no_grad():
            lmout, lmstate, _ = self.lm.predict(ys_in, None)
        lmout = self.lm_output_proj(lmout)

    out = self.embed(ys_in)
    mlen = 0  # TODO: fix later
    if self.memory_transformer:
        # NOTE: TransformerXL does not use positional encoding in the token embedding
        mems = self.init_memory()
        # adopt zero-centered offset
        pos_idxs = torch.arange(mlen - 1, -ymax - 1, -1.0, dtype=torch.float)
        if self.device_id >= 0:
            pos_idxs = pos_idxs.cuda(self.device_id)
        pos_embs = self.dropout_emb(self.pos_emb(pos_idxs))
        out = self.dropout_emb(out)
        hidden_states = [out]
    else:
        out = self.pos_enc(out)

    xy_aws_layers = []
    for l, layer in enumerate(self.layers):
        if self.memory_transformer:
            out, yy_aws, xy_aws, xy_aws_beta, yy_aws_lm = layer(
                out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout,
                pos_embs=pos_embs, memory=mems[l], u=self.u, v=self.v)
            hidden_states.append(out)
        else:
            out, yy_aws, xy_aws, xy_aws_beta, yy_aws_lm = layer(
                out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout)
        xy_aws_layers.append(xy_aws.clone() if xy_aws is not None
                             else out.new_zeros(bs, yy_aws.size(1), ymax, xtime))
        if not self.training:
            if yy_aws is not None:
                self.aws_dict['yy_aws_layer%d' % l] = tensor2np(yy_aws)
            if xy_aws is not None:
                self.aws_dict['xy_aws_layer%d' % l] = tensor2np(xy_aws)
            if xy_aws_beta is not None:
                self.aws_dict['xy_aws_beta_layer%d' % l] = tensor2np(xy_aws_beta)
            if yy_aws_lm is not None:
                self.aws_dict['yy_aws_lm_layer%d' % l] = tensor2np(yy_aws_lm)
    logits = self.output(self.norm_out(out))

    # TODO: Update memory
    # if self.memory_transformer:
    #     new_mems = self.update_memory(mems, hidden_states)

    # for knowledge distillation
    if return_logits:
        return logits

    # Compute XE sequence loss (+ label smoothing)
    loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training)

    # Attention padding
    if self.quantity_loss_weight > 0 or self.headdiv_loss_weight > 0 or self.latency_loss_weight > 0:
        for l in range(self.mocha_first_layer - 1, self.n_layers):
            n_heads = xy_aws_layers[l].size(1)
            xy_aws_layers[l] = xy_aws_layers[l].masked_fill_(
                src_mask.unsqueeze(1).repeat([1, n_heads, 1, 1]) == 0, 0)
            xy_aws_layers[l] = xy_aws_layers[l].masked_fill_(
                tgt_mask[:, :, -1:].unsqueeze(1).repeat([1, n_heads, 1, xtime]) == 0, 0)
            # NOTE: attention padding is quite effective for quantity loss
        n_heads = xy_aws_layers[-1].size(1)  # mono
        # NOTE: debug for multihead mono + multihead chunk

    # Quantity loss
    loss_quantity = 0.
    if 'mocha' in self.attn_type:
        # Average over all heads across all layers
        n_tokens_ref = tgt_mask[:, -1, :].sum(1).float()  # `[B]`
        # NOTE: count <eos> tokens
        n_tokens_pred = sum([torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                             for aws in xy_aws_layers[self.mocha_first_layer - 1:]])  # `[B]`
        n_tokens_pred /= (self.n_layers - self.mocha_first_layer + 1)
        loss_quantity = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))

    # Head divergence loss
    loss_headdiv = 0.
    if self.headdiv_loss_weight > 0.:
        # Calculate variance over all heads across all layers
        js = torch.arange(xtime, dtype=torch.float).cuda(self.device_id)
        js = js.repeat([bs, n_heads, ymax, 1])
        avg_head_pos = sum([(js * aws).sum(3).sum(1)
                            for aws in xy_aws_layers]) / (n_heads * self.n_layers)  # `[B, L]`
        loss_headdiv = sum([((js * aws).sum(3).sum(1) - avg_head_pos) ** 2
                            for aws in xy_aws_layers]) / (n_heads * self.n_layers)  # `[B, L]`
        loss_headdiv = loss_headdiv.sum() / ylens.sum()

    # Latency loss
    loss_latency = 0.
    if self.latency_metric == 'interval':
        raise NotImplementedError
    elif trigger_points is not None:
        assert self.latency_loss_weight > 0
        # Calculate weight average latency
        js = torch.arange(xtime, dtype=torch.float).cuda(self.device_id)
        js = js.repeat([bs, n_heads, ymax, 1])
        weighted_avg_head_pos = torch.cat(
            [(js * aws).sum(3) for aws in xy_aws_layers], dim=1)  # `[B, H_mono * n_layers, L]`
        weighted_avg_head_pos *= torch.softmax(weighted_avg_head_pos.clone(), dim=1)
        trigger_points = trigger_points.float().cuda(self.device_id)  # `[B, L]`
        trigger_points = trigger_points.unsqueeze(1)
        if self.latency_metric == 'ctc_sync':
            loss_latency = torch.abs(weighted_avg_head_pos - trigger_points)  # `[B, H_mono * n_layers, L]`
        else:
            raise NotImplementedError(self.latency_metric)
        # NOTE: trigger_points are padded with 0
        loss_latency = loss_latency.sum() / ylens.sum()

    # Compute token-level accuracy in teacher-forcing
    acc = compute_accuracy(logits, ys_out, self.pad)

    return loss, acc, ppl, loss_quantity, loss_headdiv, loss_latency

def forward_att(self, eouts, elens, ys, return_logits=False):
    """Compute XE loss for the sequence-to-sequence model.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        ys (list): A list of length `[B]`, which contains a list of size `[L]`
        return_logits (bool): return logits for knowledge distillation
    Returns:
        loss (FloatTensor): `[1]`
        acc (float): accuracy for token prediction
        ppl (float): perplexity

    """
    bs = eouts.size(0)

    # Append <sos> and <eos>
    ys_in_pad, ys_out_pad, ylens = self.append_sos_eos(ys, self.bwd)

    # Create the self-attention mask
    bs, ymax = ys_in_pad.size()[:2]
    yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).repeat([1, ymax, 1])
    yy_mask = yy_mask.unsqueeze(1).repeat([1, self.attn_n_heads, 1, 1])
    subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(), diagonal=0)
    subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).repeat(
        [bs, self.attn_n_heads, 1, 1])
    yy_mask = yy_mask & subsequent_mask

    # Create the source-target mask
    xmax = eouts.size(1)
    x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat([1, ymax, 1])
    y_mask = make_pad_mask(ylens, self.device_id).unsqueeze(2).repeat([1, 1, xmax])
    xy_mask = (x_mask * y_mask).unsqueeze(1).repeat([1, self.attn_n_heads, 1, 1])

    ys_emb = self.pos_enc(self.embed(ys_in_pad))
    for l in range(self.n_layers):
        ys_emb, yy_aws, xy_aws = self.layers[l](ys_emb, yy_mask, eouts, xy_mask)
        if not self.training:
            setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
            setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws))
    ys_emb = self.norm_out(ys_emb)
    logits = self.output(ys_emb)

    # for knowledge distillation
    if return_logits:
        return logits

    # Compute XE sequence loss
    if self.lsm_prob > 0 and self.training:
        # Label smoothing
        loss = cross_entropy_lsm(logits.view((-1, logits.size(2))),
                                 ys_out_pad.view(-1), self.lsm_prob, self.pad)
    else:
        loss = F.cross_entropy(logits.view((-1, logits.size(2))), ys_out_pad.view(-1),
                               ignore_index=self.pad, size_average=True)

    # Compute token-level accuracy in teacher-forcing
    acc = compute_accuracy(logits, ys_out_pad, self.pad)
    ppl = min(np.exp(loss.item()), np.inf)

    # scale loss for CTC
    loss *= ylens.float().mean()

    return loss, acc, ppl

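
# compute_accuracy is assumed to mirror the hand-rolled accuracy computation that the
# older variants above spell out inline (argmax over the vocabulary, ignore padded
# positions, report a percentage). A minimal sketch under that assumption (name and
# signature are illustrative):
import torch


def compute_accuracy_sketch(logits, ys_ref, pad):
    """logits: `[B, L, vocab]` (or `[B*L, vocab]`), ys_ref: `[B, L]` padded with `pad`."""
    pad_pred = logits.view(ys_ref.size(0), ys_ref.size(1), logits.size(-1)).argmax(2)
    mask = ys_ref != pad
    n_correct = (pad_pred.masked_select(mask) == ys_ref.masked_select(mask)).sum()
    n_total = mask.sum()
    return float(n_correct) * 100 / float(n_total)
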