def __call__(self, logits, elens, ys, ylens):
    """Forced alignment with references.

    Args:
        logits (FloatTensor): `[B, T, vocab]`
        elens (List): length `[B]`
        ys (List): length `[B]`, each of which contains a list of size `[L]`
        ylens (List): length `[B]`
    Returns:
        trigger_points (IntTensor): `[B, L]`

    """
    with torch.no_grad():
        ys = [np2tensor(np.fromiter(y, dtype=np.int64), logits.device) for y in ys]
        ys_in_pad = pad_list(ys, 0)

        # zero padding
        mask = make_pad_mask(elens.to(logits.device))
        mask = mask.unsqueeze(2).expand_as(logits)
        logits = logits.masked_fill_(mask == 0, self.log0)
        log_probs = torch.log_softmax(logits, dim=-1).transpose(0, 1)  # `[T, B, vocab]`

        trigger_points = self.align(log_probs, elens, ys_in_pad, ylens)
    return trigger_points
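# A minimal, self-contained sketch of the make_pad_mask helper that the
# snippets below rely on (hypothetical reimplementation, not the library's own
# code): the usage `masked_fill_(mask == 0, ...)` implies that valid frames are
# marked with True/1 and padded frames with False/0. Some of the older snippets
# call it with a second device_id argument; only the one-argument form is
# sketched here.
import torch

def make_pad_mask(seq_lens):
    """Return a `[B, T]` boolean mask that is True at valid (non-pad) positions."""
    bs = seq_lens.size(0)
    max_time = int(seq_lens.max())
    seq_range = torch.arange(max_time, dtype=torch.int64, device=seq_lens.device)
    return seq_range.unsqueeze(0).expand(bs, max_time) < seq_lens.unsqueeze(1)

# Example: make_pad_mask(torch.tensor([3, 5]))[0] -> [True, True, True, False, False]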
def forward(self, xs, xlens):
    """Forward computation.

    Args:
        xs (FloatTensor): `[B, T, input_dim (+Δ, ΔΔ)]`
        xlens (IntTensor): `[B]`
    Returns:
        xs (FloatTensor): `[B, T, input_dim]`

    """
    residual = xs
    xs = self.layers(xs)  # `[B, T, input_dim]`

    # padding
    device_id = torch.cuda.device_of(next(self.parameters())).idx
    mask = make_pad_mask(xlens, device_id).unsqueeze(2)  # `[B, T, 1]`
    xs = xs.clone().masked_fill_(mask == 0, 0)

    # time average
    denom = xlens.float().unsqueeze(1)
    if device_id >= 0:
        denom = denom.cuda(device_id)
    xs = xs.sum(1) / denom
    xs = residual + self.proj(xs).unsqueeze(1)
    return xs
def forward(self, xs, xlens):
    """Forward computation.

    Args:
        xs (FloatTensor): `[B, T, input_dim (+Δ, ΔΔ)]`
        xlens (IntTensor): `[B]`
    Returns:
        xs (FloatTensor): `[B, T', input_dim]`

    """
    bs, time = xs.size()[:2]

    s = xs.clone()
    for l in range(self.n_layers - 1):
        s = torch.tanh(self.ssn[l](s))
    s = self.ssn[self.n_layers - 1](s)  # `[B, T, input_dim]`

    # padding
    device_id = torch.cuda.device_of(next(self.parameters())).idx
    mask = make_pad_mask(xlens, device_id).unsqueeze(2)
    s = s.masked_fill_(mask == 0, 0)

    # time average
    s = s.sum(1) / xlens.float().cuda(device_id).unsqueeze(1)
    xs = xs + self.p(s).unsqueeze(1)
    return xs
def forward(self, xs, xlens, task): """Forward computation. Args: xs (FloatTensor): `[B, T, input_dim]` xlens (list): `[B]` task (str): not supported now Returns: eouts (dict): xs (FloatTensor): `[B, T, d_model]` xlens (list): `[B]` """ eouts = { 'ys': { 'xs': None, 'xlens': None }, 'ys_sub1': { 'xs': None, 'xlens': None }, 'ys_sub2': { 'xs': None, 'xlens': None } } if self.conv is None: xs = self.embed(xs) else: # Path through CNN blocks before RNN layers xs, xlens = self.conv(xs, xlens) # Create the self-attention mask bs, xmax = xs.size()[:2] xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(1).expand( bs, xmax, xmax) xx_mask = xx_mask.unsqueeze(1).expand(bs, self.attn_n_heads, xmax, xmax) xs = self.pos_enc(xs) for l in range(self.n_layers): xs, xx_aws = self.layers[l](xs, xx_mask) if not self.training: setattr(self, 'xx_aws_layer%d' % l, tensor2np(xx_aws)) xs = self.norm_out(xs) # Bridge layer if self.bridge is not None: xs = self.bridge(xs) eouts['ys']['xs'] = xs eouts['ys']['xlens'] = xlens return eouts
def time_restricted_mask(xs, xlens, N_l, N_c, N_r, n_chunks):
    """Create time-restricted self-attention masks for chunkwise streaming encoding.

    Args:
        xs (FloatTensor): `[B, T, d_model]`
        xlens (IntTensor): `[B]` (on CPU)
        N_l (int): number of frames of left context
        N_c (int): number of frames in the current chunk
        N_r (int): number of frames of right (lookahead) context
        n_chunks (int): number of chunks
    Returns:
        xx_mask_first (ByteTensor): `[B, T (query), T (key)]` for the first layer
        xx_mask (ByteTensor): `[B, T (query), T (key)]` for the upper layers

    """
    xx_mask = make_pad_mask(xlens.to(xs.device))
    xx_mask = xx_mask.unsqueeze(1).repeat([1, xs.size(1), 1])  # `[B, emax (query), emax (key)]`
    xx_mask_first = xx_mask.clone()
    for chunk_idx in range(n_chunks):
        offset = chunk_idx * N_c
        # for first layer
        xx_mask_first[:, offset:offset + N_c, :max(0, offset - N_l)] = 0
        xx_mask_first[:, offset:offset + N_c, offset + (N_c + N_r):] = 0
        # for upper layers
        xx_mask[:, offset:offset + N_c, :max(0, offset - N_l)] = 0
        xx_mask[:, offset:offset + N_c, offset + N_c:] = 0
    return xx_mask_first, xx_mask
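# Toy illustration of the two masks above (values are hypothetical): with
# N_l=2, N_c=2, N_r=2, T=6 and n_chunks=3, the queries of chunk 1 (frames 2-3)
# may attend to keys 0-5 in the first layer (left + current + right context),
# but only to keys 0-3 in the upper layers, so lookahead frames are consumed
# once rather than re-expanded per layer.
# xs = torch.zeros(1, 6, 8); xlens = torch.tensor([6])
# first, upper = time_restricted_mask(xs, xlens, N_l=2, N_c=2, N_r=2, n_chunks=3)
# first[0, 2].int() -> [1, 1, 1, 1, 1, 1]
# upper[0, 2].int() -> [1, 1, 1, 1, 0, 0]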
def make_san_mask(xs, xlens, unidirectional=False, lookahead=0):
    """Make a self-attention mask.

    Args:
        xs (FloatTensor): `[B, T, d_model]`
        xlens (IntTensor): `[B]` (on CPU)
        unidirectional (bool): mask out future context (causal masking)
        lookahead (int): number of lookahead frames
    Returns:
        xx_mask (ByteTensor): `[B, T (query), T (key)]`

    """
    xx_mask = make_pad_mask(xlens.to(xs.device))
    xx_mask = xx_mask.unsqueeze(1).repeat([1, xlens.max(), 1])  # `[B, emax (query), emax (key)]`
    if unidirectional:
        xx_mask = causal(xx_mask, lookahead)
    return xx_mask
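# Hedged sketch of the `causal` helper referenced above (assumption: it keeps
# only the lower-triangular part of the mask, optionally widened by `lookahead`
# future frames), together with a tiny usage example. Illustrative only, not
# the library's exact implementation.
import torch

def causal(xx_mask, lookahead):
    """Apply causal masking to a `[B, T (query), T (key)]` mask."""
    ymax = xx_mask.size(1)
    causal_mask = torch.tril(torch.ones(ymax, ymax, device=xx_mask.device),
                             diagonal=lookahead).bool()
    return xx_mask & causal_mask.unsqueeze(0)

# xs = torch.zeros(2, 4, 8); xlens = torch.tensor([4, 2])
# make_san_mask(xs, xlens, unidirectional=True).shape -> torch.Size([2, 4, 4])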
def decode(self, ys, state=None, is_asr=False):
    """Decode function.

    Args:
        ys (FloatTensor): `[B, L]`
        state: previous tokens
        is_asr (bool):
    Returns:
        ys_emb (FloatTensor): `[B, L, n_units]`
        state: previous tokens

    """
    # Concatenate previous tokens
    if is_asr and state is not None:
        ys = torch.cat([state, ys], dim=1)
        # NOTE: this is used for ASR decoding

    ys_emb = self.embed(ys.long())

    # Create the self-attention mask
    bs, ymax = ys_emb.size()[:2]
    ylens = torch.IntTensor([ymax] * bs)
    yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).expand(bs, ymax, ymax)
    yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, ymax, ymax)
    subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(), diagonal=0)
    subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(
        bs, self.attn_n_heads, ymax, ymax)
    yy_mask = yy_mask & subsequent_mask

    ys_emb = self.pos_enc(ys_emb)
    for l in range(self.n_layers):
        ys_emb, yy_aws, _ = self.layers[l](ys_emb, yy_mask)
        if not self.training:
            setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
    ys_emb = self.norm_out(ys_emb)

    if is_asr:
        state = ys
    return ys_emb, state
def decode(self, ys, ys_prev=None, cache=False):
    """Decode function.

    Args:
        ys (LongTensor): `[B, L]`
        ys_prev (LongTensor): previous tokens
        cache (bool): concatenate previous tokens
    Returns:
        logits (FloatTensor): `[B, L, vocab]`
        ys_emb (FloatTensor): `[B, L, d_model]` (for ys_prev)
        ys_prev (LongTensor): previous tokens

    """
    # Concatenate previous tokens
    if cache and ys_prev is not None:
        ys = torch.cat([ys_prev, ys], dim=1)
        # NOTE: this is used for ASR decoding

    # Create the self-attention mask
    bs, ymax = ys.size()[:2]
    ylens = torch.IntTensor([ymax] * bs)
    tgt_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).repeat([1, ymax, 1])
    subsequent_mask = tgt_mask.new_ones(ymax, ymax).byte()
    subsequent_mask = torch.tril(subsequent_mask, out=subsequent_mask).unsqueeze(0)
    tgt_mask = tgt_mask & subsequent_mask

    out = self.pos_enc(self.embed(ys.long()))
    for l in range(self.n_layers):
        out, yy_aws, _ = self.layers[l](out, tgt_mask)
        if not self.training:
            setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
    out = self.norm_out(out)

    if self.adaptive_softmax is None:
        logits = self.output(out)
    else:
        logits = out
    return logits, out, ys
def forward(self, xs, xlens):
    """Forward pass.

    Args:
        xs (FloatTensor): `[B, T, input_dim (+Δ, ΔΔ)]`
        xlens (IntTensor): `[B]`
    Returns:
        xs (FloatTensor): `[B, T, input_dim]`

    """
    residual = xs
    xs = self.layers(xs)  # `[B, T, input_dim]`

    # padding
    xlens = xlens.to(xs.device)
    mask = make_pad_mask(xlens).unsqueeze(2)  # `[B, T, 1]`
    xs = xs.clone().masked_fill_(mask == 0, 0)

    # time average
    denom = xlens.float().unsqueeze(1)
    xs = xs.sum(1) / denom
    xs = residual + self.proj(xs).unsqueeze(1)
    return xs
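# Hypothetical usage of the sequence-summary adapter above (the module name and
# constructor arguments are illustrative assumptions): the utterance is
# averaged over valid frames only, projected, and added back to every frame,
# so the output keeps the `[B, T, input_dim]` shape of the input.
# ssn = SequenceSummaryNetwork(input_dim=80, ...)   # assumed constructor
# xs = torch.randn(2, 50, 80)
# xlens = torch.tensor([50, 37], dtype=torch.int32)
# out = ssn(xs, xlens)                              # -> torch.Size([2, 50, 80])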
def forward_att(self, eouts, elens, ys, return_logits=False): """Compute XE loss for the sequence-to-sequence model. Args: eouts (FloatTensor): `[B, T, d_model]` elens (IntTensor): `[B]` ys (list): A list of length `[B]`, which contains a list of size `[L]` return_logits (bool): return logits for knowledge distillation Returns: loss (FloatTensor): `[1]` acc (float): ppl (float): """ bs = eouts.size(0) # Append <sos> and <eos> eos = eouts.new_zeros(1).fill_(self.eos).long() ys = [ np2tensor(np.fromiter(y[::-1] if self.bwd else y, dtype=np.int64), self.device_id) for y in ys ] ylens = np2tensor( np.fromiter([y.size(0) + 1 for y in ys], dtype=np.int32)) # +1 for <eos> ys_in_pad = pad_list([torch.cat([eos, y], dim=0) for y in ys], self.pad) ys_out_pad = pad_list([torch.cat([y, eos], dim=0) for y in ys], self.pad) # Create the self-attention mask bs, ymax = ys_in_pad.size()[:2] yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).expand( bs, ymax, ymax) yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, ymax, ymax) subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(), diagonal=0) subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand( bs, self.attn_n_heads, ymax, ymax) yy_mask = yy_mask & subsequent_mask # Create the source-target mask xmax = eouts.size(1) x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).expand( bs, ymax, xmax) y_mask = make_pad_mask(ylens, self.device_id).unsqueeze(2).expand( bs, ymax, xmax) xy_mask = (x_mask * y_mask).unsqueeze(1).expand( bs, self.attn_n_heads, ymax, xmax) ys_emb = self.pos_enc(self.embed(ys_in_pad)) for l in range(self.n_layers): ys_emb, yy_aws, xy_aws = self.layers[l](ys_emb, yy_mask, eouts, xy_mask) if not self.training: setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws)) setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws)) logits = self.norm_out(ys_emb) if self.adaptive_softmax is None: logits = self.output(logits) if return_logits: return logits # Compute XE sequence loss if self.adaptive_softmax is None: if self.lsm_prob > 0 and self.training: # Label smoothing loss = cross_entropy_lsm(logits.view((-1, logits.size(2))), ys_out_pad.view(-1), self.lsm_prob, self.pad) else: loss = F.cross_entropy(logits.view((-1, logits.size(2))), ys_out_pad.view(-1), ignore_index=self.pad, size_average=True) # Focal loss if self.focal_loss_weight > 0: fl = focal_loss(logits, ys_out_pad, ylens, alpha=self.focal_loss_weight, gamma=self.focal_loss_gamma) loss = loss * ( 1 - self.focal_loss_weight) + fl * self.focal_loss_weight else: loss = self.adaptive_softmax(logits.view((-1, logits.size(2))), ys_out_pad.view(-1)).loss # Compute token-level accuracy in teacher-forcing if self.adaptive_softmax is None: acc = compute_accuracy(logits, ys_out_pad, self.pad) else: acc = compute_accuracy( self.adaptive_softmax.log_prob( logits.view((-1, logits.size(2)))), ys_out_pad, self.pad) ppl = min(np.exp(loss.item()), np.inf) # scale loss for CTC loss *= ylens.float().mean() return loss, acc, ppl
def forward(self, xs, xlens, task, use_cache=False, streaming=False): """Forward computation. Args: xs (FloatTensor): `[B, T, input_dim]` xlens (list): `[B]` task (str): not supported now use_cache (bool): streaming (bool): streaming encoding Returns: eouts (dict): xs (FloatTensor): `[B, T, d_model]` xlens (list): `[B]` """ eouts = { 'ys': { 'xs': None, 'xlens': None }, 'ys_sub1': { 'xs': None, 'xlens': None }, 'ys_sub2': { 'xs': None, 'xlens': None } } N_l = self.chunk_size_left N_c = self.chunk_size_current N_r = self.chunk_size_right bs, xmax, idim = xs.size() if self.latency_controlled: xs = chunkwise(xs, N_l, N_c, N_r) if self.conv is None: xs = self.embed(xs) else: # Path through CNN blocks xs, xlens = self.conv(xs, xlens) if not self.training: self.data_dict['elens'] = tensor2np(xlens) if self.latency_controlled: # streaming Conformer encoder _N_l = max(0, N_l // self.subsampling_factor) _N_c = N_c // self.subsampling_factor n_chunks = math.ceil(xs.size(0) / bs) emax = math.ceil(xmax / self.subsampling_factor) xs = xs * self.scale pos_idxs = torch.arange(xs.size(1) - 1, -1, -1.0, dtype=torch.float) pos_embs = self.pos_emb(pos_idxs, self.device_id) xx_mask = None # NOTE: no mask for lth, layer in enumerate(self.layers): xs = layer(xs, xx_mask, pos_embs=pos_embs) if not self.training: n_heads = layer.xx_aws.size(1) xx_aws = layer.xx_aws[:, :, _N_l:_N_l + _N_c, _N_l:_N_l + _N_c] xx_aws = xx_aws.view(bs, n_chunks, n_heads, _N_c, _N_c) xx_aws_center = xx_aws.new_zeros(bs, n_heads, emax, emax) for chunk_idx in range(n_chunks): offset = chunk_idx * _N_c emax_blc = xx_aws_center[:, :, offset:offset + _N_c].size(2) xx_aws_chunk = xx_aws[:, chunk_idx, :, :emax_blc, : emax_blc] xx_aws_center[:, :, offset:offset + _N_c, offset:offset + _N_c] = xx_aws_chunk self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(xx_aws_center) # Extract the center region xs = xs[:, _N_l:_N_l + _N_c] # `[B * n_chunks, _N_c, d_model]` xs = xs.contiguous().view(bs, -1, xs.size(2)) xs = xs[:, :emax] else: bs, xmax, idim = xs.size() xs = xs * self.scale # Create the self-attention mask xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(2).repeat( [1, 1, xmax]) pos_idxs = torch.arange(xmax - 1, -1, -1.0, dtype=torch.float) pos_embs = self.pos_emb(pos_idxs, self.device_id) for lth, layer in enumerate(self.layers): xs = layer(xs, xx_mask, pos_embs=pos_embs) if not self.training: self.aws_dict['xx_aws_layer%d' % lth] = tensor2np( layer.xx_aws) # Pick up outputs in the sub task before the projection layer if lth == self.n_layers_sub1 - 1: xs_sub1 = self.layer_sub1( xs, xx_mask, pos_embs=pos_embs ) if self.task_specific_layer else xs.clone() xs_sub1 = self.norm_out_sub1(xs_sub1) if self.bridge_sub1 is not None: xs_sub1 = self.bridge_sub1(xs_sub1) if task == 'ys_sub1': eouts[task]['xs'], eouts[task][ 'xlens'] = xs_sub1, xlens return eouts if lth == self.n_layers_sub2 - 1: xs_sub2 = self.layer_sub2( xs, xx_mask, pos_embs=pos_embs ) if self.task_specific_layer else xs.clone() xs_sub2 = self.norm_out_sub2(xs_sub2) if self.bridge_sub2 is not None: xs_sub2 = self.bridge_sub2(xs_sub2) if task == 'ys_sub2': eouts[task]['xs'], eouts[task][ 'xlens'] = xs_sub2, xlens return eouts xs = self.norm_out(xs) # Bridge layer if self.bridge is not None: xs = self.bridge(xs) if task in ['all', 'ys']: eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens if self.n_layers_sub1 >= 1 and task == 'all': eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens if self.n_layers_sub2 >= 1 and task == 'all': eouts['ys_sub2']['xs'], 
eouts['ys_sub2']['xlens'] = xs_sub2, xlens return eouts
def forward(self, xs, xlens, task, use_cache=False, streaming=False): """Forward computation. Args: xs (FloatTensor): `[B, T, input_dim]` xlens (list): `[B]` task (str): not supported now use_cache (bool): streaming (bool): streaming encoding Returns: eouts (dict): xs (FloatTensor): `[B, T, d_model]` xlens (list): `[B]` """ eouts = { 'ys': { 'xs': None, 'xlens': None }, 'ys_sub1': { 'xs': None, 'xlens': None }, 'ys_sub2': { 'xs': None, 'xlens': None } } if self.latency_controlled: bs, xmax, idim = xs.size() n_blocks = xmax // self.N_c if xmax % self.N_c != 0: n_blocks += 1 xs_tmp = xs.new_zeros(bs, n_blocks, self.N_l + self.N_c + self.N_r, idim) xs_pad = torch.cat([ xs.new_zeros(bs, self.N_l, idim), xs, xs.new_zeros(bs, self.N_r, idim) ], dim=1) for blc_id, t in enumerate( range(self.N_l, self.N_l + xmax, self.N_c)): xs_chunk = xs_pad[:, t - self.N_l:t + (self.N_c + self.N_r)] xs_tmp[:, blc_id, :xs_chunk.size(1), :] = xs_chunk xs = xs_tmp.view(bs * n_blocks, self.N_l + self.N_c + self.N_r, idim) if self.conv is None: xs = self.embed(xs) else: # Path through CNN blocks xs, xlens = self.conv(xs, xlens) if not self.training: self.data_dict['elens'] = tensor2np(xlens) if self.latency_controlled: N_l = max(0, self.N_l // self.subsampling_factor) N_c = self.N_c // self.subsampling_factor emax = xmax // self.subsampling_factor if xmax % self.subsampling_factor != 0: emax += 1 xs = self.pos_enc(xs, scale=True) xx_mask = None for lth, layer in enumerate(self.layers): xs, xx_aws = layer(xs, xx_mask) if not self.training: n_heads = xx_aws.size(1) xx_aws = xx_aws[:, :, N_l:N_l + N_c, N_l:N_l + N_c] xx_aws = xx_aws.view(bs, n_blocks, n_heads, N_c, N_c) xx_aws_center = xx_aws.new_zeros(bs, n_heads, emax, emax) for blc_id in range(n_blocks): offset = blc_id * N_c emax_blc = xx_aws_center[:, :, offset:offset + N_c].size(2) xx_aws_chunk = xx_aws[:, blc_id, :, :emax_blc, :emax_blc] xx_aws_center[:, :, offset:offset + N_c, offset:offset + N_c] = xx_aws_chunk self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(xx_aws_center) # Extract the center region xs = xs[:, N_l:N_l + N_c] # `[B * n_blocks, N_c // subsampling_factor, d_model]` xs = xs.contiguous().view(bs, -1, xs.size(2)) xs = xs[:, :emax] else: bs, xmax, idim = xs.size() xs = self.pos_enc(xs, scale=True) # Create the self-attention mask xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(2).repeat( [1, 1, xmax]) for lth, layer in enumerate(self.layers): xs, xx_aws = layer(xs, xx_mask) if not self.training: self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(xx_aws) # Pick up outputs in the sub task before the projection layer if lth == self.n_layers_sub1 - 1: xs_sub1 = self.layer_sub1( xs, xx_mask )[0] if self.task_specific_layer else xs.clone() xs_sub1 = self.norm_out_sub1(xs_sub1) if self.bridge_sub1 is not None: xs_sub1 = self.bridge_sub1(xs_sub1) if task == 'ys_sub1': eouts[task]['xs'], eouts[task][ 'xlens'] = xs_sub1, xlens return eouts if lth == self.n_layers_sub2 - 1: xs_sub2 = self.layer_sub2( xs, xx_mask )[0] if self.task_specific_layer else xs.clone() xs_sub2 = self.norm_out_sub2(xs_sub2) if self.bridge_sub2 is not None: xs_sub2 = self.bridge_sub2(xs_sub2) if task == 'ys_sub2': eouts[task]['xs'], eouts[task][ 'xlens'] = xs_sub2, xlens return eouts xs = self.norm_out(xs) # Bridge layer if self.bridge is not None: xs = self.bridge(xs) if task in ['all', 'ys']: eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens if self.n_layers_sub1 >= 1 and task == 'all': eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens if self.n_layers_sub2 
>= 1 and task == 'all': eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens return eouts
def align(self, logits, elens, ys, ylens, add_eos=True):
    """Calculate the best CTC alignment with the forward-backward algorithm.

    Args:
        logits (FloatTensor): `[B, T, vocab]`
        elens (FloatTensor): `[B]`
        ys (FloatTensor): `[B, L]`
        ylens (FloatTensor): `[B]`
        add_eos (bool): use the last time index as a boundary corresponding to <eos>
    Returns:
        trigger_points (IntTensor): `[B, L]`

    """
    bs, xmax, vocab = logits.size()
    device = logits.device

    # zero padding
    mask = make_pad_mask(elens.to(device))
    mask = mask.unsqueeze(2).repeat([1, 1, vocab])
    logits = logits.masked_fill_(mask == 0, self.log0)
    log_probs = torch.log_softmax(logits, dim=-1).transpose(0, 1)  # `[T, B, vocab]`

    path = _label_to_path(ys, self.blank)
    path_lens = 2 * ylens.long() + 1

    ymax = ys.size(1)
    max_path_len = path.size(1)
    assert ys.size() == (bs, ymax), ys.size()
    assert path.size() == (bs, ymax * 2 + 1)

    alpha = log_probs.new_zeros(bs, max_path_len).fill_(self.log0)
    alpha[:, 0] = LOG_1
    beta = alpha.clone()
    gamma = alpha.clone()

    batch_index = torch.arange(bs, dtype=torch.int64).unsqueeze(1)
    seq_index = torch.arange(xmax, dtype=torch.int64).unsqueeze(1).unsqueeze(2)
    log_probs_fwd_bwd = log_probs[seq_index, batch_index, path]

    # forward algorithm
    for t in range(xmax):
        alpha = self._computes_transition(alpha, path, path_lens,
                                          log_probs_fwd_bwd[t], log_probs[t])

    # backward algorithm
    r_path = _flip_path(path, path_lens)
    log_probs_inv = _flip_label_probability(log_probs, elens.long())  # `[T, B, vocab]`
    log_probs_fwd_bwd = _flip_path_probability(log_probs_fwd_bwd, elens.long(),
                                               path_lens)  # `[T, B, 2*L+1]`
    for t in range(xmax):
        beta = self._computes_transition(beta, r_path, path_lens,
                                         log_probs_fwd_bwd[t], log_probs_inv[t])

    # pick up the best CTC path
    best_aligns = log_probs.new_zeros((bs, xmax), dtype=torch.int64)

    # forward algorithm
    log_probs_fwd_bwd = _flip_path_probability(log_probs_fwd_bwd, elens.long(), path_lens)
    for t in range(xmax):
        gamma = self._computes_transition(gamma, path, path_lens,
                                          log_probs_fwd_bwd[t], log_probs[t],
                                          skip_accum=True)

        # select paths where gamma is valid
        log_probs_fwd_bwd[t] = log_probs_fwd_bwd[t].masked_fill_(gamma == self.log0, self.log0)

        # pick up the best alignment
        offsets = log_probs_fwd_bwd[t].argmax(1)
        for b in range(bs):
            if t <= elens[b] - 1:
                token_idx = path[b, offsets[b]]
                best_aligns[b, t] = token_idx

        # remove the rest of paths
        gamma = log_probs.new_zeros(bs, max_path_len).fill_(self.log0)
        for b in range(bs):
            gamma[b, offsets[b]] = LOG_1

    # pick up trigger points
    trigger_aligns = torch.zeros((bs, xmax), dtype=torch.int64)
    trigger_points = log_probs.new_zeros((bs, ymax + 1), dtype=torch.int32)  # +1 for <eos>
    for b in range(bs):
        n_triggers = 0
        if add_eos:
            trigger_points[b, ylens[b]] = elens[b] - 1
            # NOTE: use the last time index as a boundary corresponding to <eos>
            # Otherwise, index 0 is used for <eos>
        for t in range(elens[b]):
            token_idx = best_aligns[b, t]
            if token_idx == self.blank:
                continue
            if not (t == 0 or token_idx != best_aligns[b, t - 1]):
                continue  # NOTE: select the left-most trigger point
            trigger_aligns[b, t] = token_idx
            trigger_points[b, n_triggers] = t
            n_triggers += 1

    assert ylens.sum() == (trigger_aligns != 0).sum()
    return trigger_points
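# Worked toy example of the trigger-point extraction above (values are
# illustrative; the blank index is assumed to be 0): the left-most frame of
# each run of a non-blank token in the best alignment becomes that token's
# trigger point, and with add_eos=True the last valid frame marks <eos>.
# best_aligns[b] = [0, 3, 3, 0, 5, 0, 0]   # elens[b] = 7, ys[b] = [3, 5]
# trigger_points[b] -> [1, 4, 6]           # 6 == elens[b] - 1 for <eos>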
def forward_att(self, eouts, elens, ys, return_logits=False): """Compute XE loss for the sequence-to-sequence model. Args: eouts (FloatTensor): `[B, T, d_model]` elens (IntTensor): `[B]` ys (list): A list of length `[B]`, which contains a list of size `[L]` return_logits (bool): return logits for knowledge distillation Returns: loss (FloatTensor): `[1]` acc (float): accuracy for token prediction ppl (float): perplexity """ bs = eouts.size(0) # Append <sos> and <eos> ys_in_pad, ys_out_pad, ylens = self.append_sos_eos(ys, self.bwd) # Create the self-attention mask bs, ymax = ys_in_pad.size()[:2] yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).repeat( [1, ymax, 1]) yy_mask = yy_mask.unsqueeze(1).repeat([1, self.attn_n_heads, 1, 1]) subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(), diagonal=0) subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).repeat( [bs, self.attn_n_heads, 1, 1]) yy_mask = yy_mask & subsequent_mask # Create the source-target mask xmax = eouts.size(1) x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat( [1, ymax, 1]) y_mask = make_pad_mask(ylens, self.device_id).unsqueeze(2).repeat( [1, 1, xmax]) xy_mask = (x_mask * y_mask).unsqueeze(1).repeat( [1, self.attn_n_heads, 1, 1]) ys_emb = self.pos_enc(self.embed(ys_in_pad)) for l in range(self.n_layers): ys_emb, yy_aws, xy_aws = self.layers[l](ys_emb, yy_mask, eouts, xy_mask) if not self.training: setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws)) setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws)) ys_emb = self.norm_out(ys_emb) logits = self.output(ys_emb) # for knowledge distillation if return_logits: return logits # Compute XE sequence loss if self.lsm_prob > 0 and self.training: # Label smoothing loss = cross_entropy_lsm(logits.view((-1, logits.size(2))), ys_out_pad.view(-1), self.lsm_prob, self.pad) else: loss = F.cross_entropy(logits.view((-1, logits.size(2))), ys_out_pad.view(-1), ignore_index=self.pad, size_average=True) # Compute token-level accuracy in teacher-forcing acc = compute_accuracy(logits, ys_out_pad, self.pad) ppl = min(np.exp(loss.item()), np.inf) # scale loss for CTC loss *= ylens.float().mean() return loss, acc, ppl
def forward(self, xs, xlens, task, use_cache=False, streaming=False): """Forward computation. Args: xs (FloatTensor): `[B, T, input_dim]` xlens (list): `[B]` task (str): not supported now use_cache (bool): streaming (bool): streaming encoding Returns: eouts (dict): xs (FloatTensor): `[B, T, d_model]` xlens (list): `[B]` """ eouts = { 'ys': { 'xs': None, 'xlens': None }, 'ys_sub1': { 'xs': None, 'xlens': None }, 'ys_sub2': { 'xs': None, 'xlens': None } } if self.conv is None: xs = self.embed(xs) else: # Path through CNN blocks before RNN layers xs, xlens = self.conv(xs, xlens) if not self.training: self.data_dict['elens'] = tensor2np(xlens) bs, xmax, idim = xs.size() if self.memory_transformer or self.latency_controlled: # streaming Transformer(XL) encoder N_l = max(0, self.chunk_size_left // self.subsampling_factor) N_c = self.chunk_size_cur // self.subsampling_factor N_r = max(0, self.chunk_size_right // self.subsampling_factor) xs_chunks = [] xx_aws = [[] for _ in range(self.n_layers)] mems = self.init_memory() self.reset_cache() # for LC-BLSTM mlen = 0 for t in range(0, xmax, N_c): clen = min(N_c, xmax - 1 - t + 1) rlen = 0 if xmax - 1 - (t + clen) + 1 > 0: rlen = min(N_r, xmax - 1 - (t + clen) + 1) xs_chunk = xs[:, t:t + (clen + rlen)] if self.hybrid_rnn: for lth in range(self.n_layers_rnn): self.rnn[lth].flatten_parameters() # for multi-GPUs self.rnn_bwd[lth].flatten_parameters( ) # for multi-GPUs # bwd xs_chunk_bwd = torch.flip(xs_chunk, dims=[1]) xs_chunk_bwd, _ = self.rnn_bwd[lth](xs_chunk_bwd, hx=None) xs_chunk_bwd = torch.flip( xs_chunk_bwd, dims=[1]) # `[B, clen+rlen, d_model]` # fwd if xs_chunk.size(1) <= clen: xs_chunk_fwd, self.fwd_states[lth] = self.rnn[lth]( xs_chunk, hx=self.fwd_states[lth]) else: xs_chunk_fwd1, self.fwd_states[lth] = self.rnn[ lth](xs_chunk[:, :clen], hx=self.fwd_states[lth]) xs_chunk_fwd2, _ = self.rnn[lth]( xs_chunk[:, clen:], hx=self.fwd_states[lth]) xs_chunk_fwd = torch.cat( [xs_chunk_fwd1, xs_chunk_fwd2], dim=1) # `[B, clen+rlen, d_model]` # NOTE: xs_chunk_fwd2 is for xs_chunk_bwd in the next layer if self.bidir_sum: xs_chunk = xs_chunk_fwd + xs_chunk_bwd else: xs_chunk = torch.cat([xs_chunk_fwd, xs_chunk_bwd], dim=-1) xs_chunk = self.dropout_rnn(xs_chunk) if self.proj is not None: xs_chunk = self.proj(xs_chunk) xs_chunk = self.pos_enc(xs_chunk, scale=True) # for scale if self.memory_transformer: # adopt zero-centered offset pos_idxs = torch.arange(mlen - 1, -xs_chunk.size(1) - 1, -1.0, dtype=torch.float) pos_embs = self.pos_emb(pos_idxs, self.device_id) hidden_states = [xs_chunk[:, :clen][:, -N_l:]] for lth, (mem, layer) in enumerate(zip(mems, self.layers)): if self.memory_transformer: xs_chunk, xx_aws_chunk = layer(xs_chunk, None, pos_embs=pos_embs, memory=mem, u=self.u, v=self.v) # no mask else: xs_chunk, xx_aws_chunk = layer(xs_chunk, None, memory=mem) # no mask if lth < self.n_layers - 1: hidden_states.append(xs_chunk[:, :clen][:, -N_l:]) # NOTE: xx_aws_chunk: `[B, H, clen+rlen (query), mlen+clen+rlen (key)]` xx_aws_chunk = xx_aws_chunk[:, :, :clen, mlen:mlen + clen] assert xx_aws_chunk.size(2) == xx_aws_chunk.size(3) xx_aws_chunk_pad = xs.new_zeros( (bs, xx_aws_chunk.size(1), N_c, N_c)) xx_aws_chunk_pad[:, :, :xx_aws_chunk.size(2), :xx_aws_chunk .size(3)] = xx_aws_chunk xx_aws[lth].append(xx_aws_chunk_pad) mems = self.update_memory(mems, hidden_states) mlen = mems[0].size(1) if mems[0].dim() > 1 else 0 xs_chunks.append(xs_chunk[:, :clen]) xs = torch.cat(xs_chunks, dim=1)[:, :xmax] if not self.training: for lth in range(self.n_layers): 
self.aws_dict['xx_aws_layer%d' % lth] = tensor2np( torch.cat(xx_aws[lth], dim=3)[:, :, :xmax, :xmax]) else: # Hybrid RNN-Transformer if self.hybrid_rnn: for lth in range(self.n_layers_rnn): self.rnn[lth].flatten_parameters() # for multi-GPUs self.rnn_bwd[lth].flatten_parameters() # for multi-GPUs # bwd xs_bwd = torch.flip(xs, dims=[1]) xs_bwd, _ = self.rnn_bwd[lth](xs_bwd, hx=None) xs_bwd = torch.flip(xs_bwd, dims=[1]) # fwd xs_fwd, _ = self.rnn[lth](xs, hx=None) # NOTE: no padding because inputs are not sorted if self.bidir_sum: xs = xs_fwd + xs_bwd else: xs = torch.cat([xs_fwd, xs_bwd], dim=-1) xs = self.dropout_rnn(xs) if self.proj is not None: xs = self.proj(xs) xs = self.pos_enc(xs, scale=True) # Create the self-attention mask xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(2).repeat( [1, 1, xmax]) for lth, layer in enumerate(self.layers): xs, xx_aws = layer(xs, xx_mask) if not self.training: self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(xx_aws) # Pick up outputs in the sub task before the projection layer if lth == self.n_layers_sub1 - 1: xs_sub1 = self.layer_sub1( xs, xx_mask )[0] if self.task_specific_layer else xs.clone() xs_sub1 = self.norm_out_sub1(xs_sub1) if self.bridge_sub1 is not None: xs_sub1 = self.bridge_sub1(xs_sub1) if task == 'ys_sub1': eouts[task]['xs'], eouts[task][ 'xlens'] = xs_sub1, xlens return eouts if lth == self.n_layers_sub2 - 1: xs_sub2 = self.layer_sub2( xs, xx_mask )[0] if self.task_specific_layer else xs.clone() xs_sub2 = self.norm_out_sub2(xs_sub2) if self.bridge_sub2 is not None: xs_sub2 = self.bridge_sub2(xs_sub2) if task == 'ys_sub2': eouts[task]['xs'], eouts[task][ 'xlens'] = xs_sub2, xlens return eouts xs = self.norm_out(xs) # Bridge layer if self.bridge is not None: xs = self.bridge(xs) if task in ['all', 'ys']: eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens if self.n_layers_sub1 >= 1 and task == 'all': eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens if self.n_layers_sub2 >= 1 and task == 'all': eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens return eouts
def forward_att(self, eouts, elens, ys, return_logits=False, teacher_logits=None, trigger_points=None): """Compute XE loss for the Transformer decoder. Args: eouts (FloatTensor): `[B, T, d_model]` elens (IntTensor): `[B]` ys (list): length `B`, each of which contains a list of size `[L]` return_logits (bool): return logits for knowledge distillation teacher_logits (FloatTensor): `[B, L, vocab]` trigger_points (IntTensor): `[B, T]` Returns: loss (FloatTensor): `[1]` acc (float): accuracy for token prediction ppl (float): perplexity loss_quantity (FloatTensor): `[1]` loss_headdiv (FloatTensor): `[1]` loss_latency (FloatTensor): `[1]` """ # Append <sos> and <eos> ys_in, ys_out, ylens = append_sos_eos(eouts, ys, self.eos, self.eos, self.pad, self.bwd) if not self.training: self.data_dict['elens'] = tensor2np(elens) self.data_dict['ylens'] = tensor2np(ylens) self.data_dict['ys'] = tensor2np(ys_out) # Create target self-attention mask xtime = eouts.size(1) bs, ymax = ys_in.size()[:2] tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1]) causal_mask = tgt_mask.new_ones(ymax, ymax).byte() causal_mask = torch.tril(causal_mask, out=causal_mask).unsqueeze(0) tgt_mask = tgt_mask & causal_mask # `[B, L, L]` # Create source-target mask src_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat( [1, ymax, 1]) # `[B, L, T]` # external LM integration lmout = None if self.lm is not None: self.lm.eval() with torch.no_grad(): lmout, lmstate, _ = self.lm.predict(ys_in, None) lmout = self.lm_output_proj(lmout) out = self.embed(ys_in) mlen = 0 # TODO: fix later if self.memory_transformer: # NOTE: TransformerXL does not use positional encoding in the token embedding mems = self.init_memory() # adopt zero-centered offset pos_idxs = torch.arange(mlen - 1, -ymax - 1, -1.0, dtype=torch.float) if self.device_id >= 0: pos_idxs = pos_idxs.cuda(self.device_id) pos_embs = self.dropout_emb(self.pos_emb(pos_idxs)) out = self.dropout_emb(out) hidden_states = [out] else: out = self.pos_enc(out) xy_aws_layers = [] for l, layer in enumerate(self.layers): if self.memory_transformer: out, yy_aws, xy_aws, xy_aws_beta, yy_aws_lm = layer( out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout, pos_embs=pos_embs, memory=mems[l], u=self.u, v=self.v) hidden_states.append(out) else: out, yy_aws, xy_aws, xy_aws_beta, yy_aws_lm = layer( out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout) xy_aws_layers.append(xy_aws.clone() if xy_aws is not None else out. 
new_zeros(bs, yy_aws.size(1), ymax, xtime)) if not self.training: if yy_aws is not None: self.aws_dict['yy_aws_layer%d' % l] = tensor2np(yy_aws) if xy_aws is not None: self.aws_dict['xy_aws_layer%d' % l] = tensor2np(xy_aws) if xy_aws_beta is not None: self.aws_dict['xy_aws_beta_layer%d' % l] = tensor2np(xy_aws_beta) if yy_aws_lm is not None: self.aws_dict['yy_aws_lm_layer%d' % l] = tensor2np(yy_aws_lm) logits = self.output(self.norm_out(out)) # TODO: Update memory # if self.memory_transformer: # new_mems = self.update_memory(mems, hidden_states) # for knowledge distillation if return_logits: return logits # Compute XE sequence loss (+ label smoothing) loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training) # Attention padding if self.quantity_loss_weight > 0 or self.headdiv_loss_weight > 0 or self.latency_loss_weight > 0: for l in range(self.mocha_first_layer - 1, self.n_layers): n_heads = xy_aws_layers[l].size(1) xy_aws_layers[l] = xy_aws_layers[l].masked_fill_( src_mask.unsqueeze(1).repeat([1, n_heads, 1, 1]) == 0, 0) xy_aws_layers[l] = xy_aws_layers[l].masked_fill_( tgt_mask[:, :, -1:].unsqueeze(1).repeat([1, n_heads, 1, xtime]) == 0, 0) # NOTE: attention padding is quite effective for quantity loss n_heads = xy_aws_layers[-1].size(1) # mono # NOTE: debug for multihead mono + multihead chunk # Quantity loss loss_quantity = 0. if 'mocha' in self.attn_type: # Average over all heads across all layers n_tokens_ref = tgt_mask[:, -1, :].sum(1).float() # `[B]` # NOTE: count <eos> tokens n_tokens_pred = sum([ torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1)) for aws in xy_aws_layers[self.mocha_first_layer - 1:] ]) # `[B]` n_tokens_pred /= (self.n_layers - self.mocha_first_layer + 1) loss_quantity = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref)) # Head divergence loss loss_headdiv = 0. if self.headdiv_loss_weight > 0.: # Calculate variance over all heads across all layers js = torch.arange(xtime, dtype=torch.float).cuda(self.device_id) js = js.repeat([bs, n_heads, ymax, 1]) avg_head_pos = sum([ (js * aws).sum(3).sum(1) for aws in xy_aws_layers ]) / (n_heads * self.n_layers) # `[B, L]` loss_headdiv = sum([((js * aws).sum(3).sum(1) - avg_head_pos)**2 for aws in xy_aws_layers]) / ( n_heads * self.n_layers) # `[B, L]` loss_headdiv = loss_headdiv.sum() / ylens.sum() # Latency loss loss_latency = 0. if self.latency_metric == 'interval': raise NotImplementedError elif trigger_points is not None: assert self.latency_loss_weight > 0 # Calculate weight average latency js = torch.arange(xtime, dtype=torch.float).cuda(self.device_id) js = js.repeat([bs, n_heads, ymax, 1]) weighted_avg_head_pos = torch.cat( [(js * aws).sum(3) for aws in xy_aws_layers], dim=1) # `[B, H_mono * n_layers, L]` weighted_avg_head_pos *= torch.softmax( weighted_avg_head_pos.clone(), dim=1) trigger_points = trigger_points.float().cuda( self.device_id) # `[B, L]` trigger_points = trigger_points.unsqueeze(1) if self.latency_metric == 'ctc_sync': loss_latency = torch.abs( weighted_avg_head_pos - trigger_points) # `[B, H_mono * n_layers, L]` else: raise NotImplementedError(self.latency_metric) # NOTE: trigger_points are padded with 0 loss_latency = loss_latency.sum() / ylens.sum() # Compute token-level accuracy in teacher-forcing acc = compute_accuracy(logits, ys_out, self.pad) return loss, acc, ppl, loss_quantity, loss_headdiv, loss_latency
def forward(self, xs, xlens, task, use_cache=False, streaming=False): """Forward pass. Args: xs (FloatTensor): `[B, T, input_dim]` xlens (InteTensor): `[B]` (on CPU) task (str): ys/ys_sub1/ys_sub2 use_cache (bool): streaming (bool): streaming encoding Returns: eouts (dict): xs (FloatTensor): `[B, T, d_model]` xlens (InteTensor): `[B]` (on CPU) """ eouts = { 'ys': { 'xs': None, 'xlens': None }, 'ys_sub1': { 'xs': None, 'xlens': None }, 'ys_sub2': { 'xs': None, 'xlens': None } } N_l = self.chunk_size_left N_c = self.chunk_size_current N_r = self.chunk_size_right bs, xmax, idim = xs.size() n_chunks = 0 if self.latency_controlled: if self.lc_type == 'reshape': xs = chunkwise(xs, N_l, N_c, N_r) # `[B * n_chunks, N_l+N_c+N_r, idim]` elif self.lc_type == 'mask': # xs = chunkwise(xs, N_l, N_c, N_r) # `[B * n_chunks, N_l+N_c+N_r, idim]` xs = chunkwise(xs, 0, N_c, 0) # `[B * n_chunks, N_c, idim]` else: raise ValueError n_chunks = xs.size(0) // bs if self.conv is None: xs = self.embed(xs) else: # Path through CNN blocks xs, xlens = self.conv(xs, xlens) N_l = max(0, N_l // self.conv.subsampling_factor) N_c = N_c // self.conv.subsampling_factor N_r = N_r // self.conv.subsampling_factor if self.lc_type == 'mask': # Extract the center region emax = xlens.max().item() # xs = xs[:, N_l:N_l + N_c] # `[B * n_chunks, N_c, d_model]` xs = xs.contiguous().view(bs, -1, xs.size(2)) xs = xs[:, :emax] # `[B, emax, d_model]` if self.latency_controlled: # streaming Transformer encoder emax = xlens.max().item() pos_embs = None if self.pe_type == 'relative': xs = xs * self.scale pos_idxs = torch.arange(xs.size(1) - 1, -1, -1.0, dtype=torch.float, device=self.device) pos_embs = self.pos_emb(pos_idxs) else: xs = self.pos_enc(xs, scale=True) xx_mask = None # NOTE: no mask to avoid all masked region for lth, layer in enumerate(self.layers): xs = layer(xs, xx_mask, pos_embs=pos_embs) if not self.training: if self.lc_type == 'reshape': n_heads = layer.xx_aws.size(1) xx_aws = layer.xx_aws[:, :, N_l:N_l + N_c, N_l:N_l + N_c] xx_aws = xx_aws.view(bs, n_chunks, n_heads, N_c, N_c) xx_aws_center = xx_aws.new_zeros( bs, n_heads, emax, emax) for chunk_idx in range(n_chunks): offset = chunk_idx * N_c emax_chunk = xx_aws_center[:, :, offset:offset + N_c].size(2) xx_aws_chunk = xx_aws[:, chunk_idx, :, : emax_chunk, :emax_chunk] xx_aws_center[:, :, offset:offset + N_c, offset:offset + N_c] = xx_aws_chunk elif self.lc_type == 'mask': self.aws_dict['xx_aws_layer%d' % lth] = tensor2np( layer.xx_aws) else: raise ValueError self.data_dict['elens%d' % lth] = tensor2np(xlens) if self.subsample is not None: xs, xlens = self.subsample[lth](xs, xlens) emax = xlens.max().item() N_l = max(0, N_l // self.subsample[lth].subsampling_factor) N_c = N_c // self.subsample[lth].subsampling_factor N_r = N_r // self.subsample[lth].subsampling_factor if self.lc_type == 'mask': xx_mask = make_pad_mask(xlens.to(self.device)) xx_mask = xx_mask.unsqueeze(1).repeat( [1, xs.size(1), 1]) # `[B, emax (query), emax (key)]` for chunk_idx in range(n_chunks): offset = chunk_idx * N_c xx_mask[:, offset:offset + N_c, :max(0, offset - N_l)] = 0 xx_mask[:, offset:offset + N_c, offset + (N_c + N_r):] = 0 # Extract the center region if self.lc_type == 'reshape': xs = xs[:, N_l:N_l + N_c] # `[B * n_chunks, N_c, d_model]` xs = xs.contiguous().view(bs, -1, xs.size(2)) xs = xs[:, :emax] else: if self.pe_type == 'relative': xs = xs * self.scale # Create sinusoidal positional embeddings for relative positional encoding pos_idxs = torch.arange(xs.size(1) - 1, -1, -1.0, 
dtype=torch.float, device=self.device) pos_embs = self.pos_emb(pos_idxs) else: xs = self.pos_enc(xs, scale=True) pos_embs = None # Create the self-attention mask xx_mask = make_pad_mask(xlens.to(self.device)).unsqueeze(1).repeat( [1, xs.size(1), 1]) for lth, layer in enumerate(self.layers): xs = layer(xs, xx_mask, pos_embs=pos_embs) if not self.training: self.aws_dict['xx_aws_layer%d' % lth] = tensor2np( layer.xx_aws) self.data_dict['elens%d' % lth] = tensor2np(xlens) # Pick up outputs in the sub task before the projection layer if lth == self.n_layers_sub1 - 1: xs_sub1 = self.sub_module(xs, xx_mask, lth, pos_embs, 'sub1') if task == 'ys_sub1': eouts[task]['xs'], eouts[task][ 'xlens'] = xs_sub1, xlens return eouts if lth == self.n_layers_sub2 - 1: xs_sub2 = self.sub_module(xs, xx_mask, lth, pos_embs, 'sub2') if task == 'ys_sub2': eouts[task]['xs'], eouts[task][ 'xlens'] = xs_sub2, xlens return eouts if self.subsample is not None: xs, xlens = self.subsample[lth](xs, xlens) # Create the self-attention mask xx_mask = make_pad_mask(xlens.to( self.device)).unsqueeze(1).repeat([1, xs.size(1), 1]) if self.pe_type == 'relative': # Create sinusoidal positional embeddings for relative positional encoding pos_idxs = torch.arange(xs.size(1) - 1, -1, -1.0, dtype=torch.float, device=self.device) pos_embs = self.pos_emb(pos_idxs) xs = self.norm_out(xs) # Bridge layer if self.bridge is not None: xs = self.bridge(xs) if task in ['all', 'ys']: eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens if self.n_layers_sub1 >= 1 and task == 'all': eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens if self.n_layers_sub2 >= 1 and task == 'all': eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens return eouts
def forward_att(self, eouts, elens, ys, return_logits=False, teacher_logits=None, trigger_points=None): """Compute XE loss for the Transformer decoder. Args: eouts (FloatTensor): `[B, T, d_model]` elens (IntTensor): `[B]` ys (list): length `B`, each of which contains a list of size `[L]` return_logits (bool): return logits for knowledge distillation teacher_logits (FloatTensor): `[B, L, vocab]` trigger_points (IntTensor): `[B, T]` Returns: loss (FloatTensor): `[1]` acc (float): accuracy for token prediction ppl (float): perplexity loss_quantity (FloatTensor): `[1]` loss_headdiv (FloatTensor): `[1]` loss_latency (FloatTensor): `[1]` """ # Append <sos> and <eos> ys_in, ys_out, ylens = append_sos_eos(eouts, ys, self.eos, self.eos, self.pad, self.bwd) if not self.training: self.data_dict['elens'] = tensor2np(elens) self.data_dict['ylens'] = tensor2np(ylens) self.data_dict['ys'] = tensor2np(ys_out) # Create target self-attention mask xmax = eouts.size(1) bs, ymax = ys_in.size()[:2] mlen = 0 tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1]) causal_mask = tgt_mask.new_ones(ymax, ymax).byte() causal_mask = torch.tril(causal_mask, diagonal=0 + mlen, out=causal_mask).unsqueeze(0) tgt_mask = tgt_mask & causal_mask # `[B, L (query), L (key)]` # Create source-target mask src_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat([1, ymax, 1]) # `[B, L, T]` # external LM integration lmout = None if self.lm is not None: self.lm.eval() with torch.no_grad(): lmout, lmstate, _ = self.lm.predict(ys_in, None) lmout = self.lm_output_proj(lmout) out = self.pos_enc(self.embed(ys_in)) # scaled mems = self.init_memory() pos_embs = None if self.memory_transformer: out = self.dropout_emb(out) # NOTE: TransformerXL does not use positional encoding in the token embedding # adopt zero-centered offset pos_idxs = torch.arange(mlen - 1, -ymax - 1, -1.0, dtype=torch.float) pos_embs = self.pos_emb(pos_idxs, self.device_id) hidden_states = [out] xy_aws_layers = [] for lth, (mem, layer) in enumerate(zip(mems, self.layers)): out = layer(out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout, pos_embs=pos_embs, memory=mem, u=self.u, v=self.v) if lth < self.n_layers - 1: hidden_states.append(out) # NOTE: outputs from the last layer is not used for momory # Attention padding xy_aws = layer.xy_aws if xy_aws is not None and 'mocha' in self.attn_type: tgt_mask_v2 = (ys_out != self.pad).unsqueeze(1).unsqueeze(3) # `[B, 1, L, 1]` xy_aws = xy_aws.masked_fill_(tgt_mask_v2.repeat([1, xy_aws.size(1), 1, xmax]) == 0, 0) # NOTE: attention padding is quite effective for quantity loss xy_aws_layers.append(xy_aws.clone()) if not self.training: if layer.yy_aws is not None: self.aws_dict['yy_aws_layer%d' % lth] = tensor2np(layer.yy_aws) if layer.xy_aws is not None: self.aws_dict['xy_aws_layer%d' % lth] = tensor2np(layer.xy_aws) if layer.xy_aws_beta is not None: self.aws_dict['xy_aws_beta_layer%d' % lth] = tensor2np(layer.xy_aws_beta) if layer.xy_aws_p_choose is not None: self.aws_dict['xy_aws_p_choose%d' % lth] = tensor2np(layer.xy_aws_p_choose) if layer.yy_aws_lm is not None: self.aws_dict['yy_aws_lm_layer%d' % lth] = tensor2np(layer.yy_aws_lm) logits = self.output(self.norm_out(out)) # for knowledge distillation if return_logits: return logits # Compute XE loss (+ label smoothing) loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training) losses_auxiliary = {} # Quantity loss losses_auxiliary['loss_quantity'] = 0. 
if 'mocha' in self.attn_type: # Average over all heads across all layers n_tokens_ref = tgt_mask[:, -1, :].sum(1).float() # `[B]` # NOTE: count <eos> tokens n_tokens_pred = sum([torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1)) for aws in xy_aws_layers]) # `[B]` n_tokens_pred /= len(xy_aws_layers) losses_auxiliary['loss_quantity'] = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref)) # Compute token-level accuracy in teacher-forcing acc = compute_accuracy(logits, ys_out, self.pad) return loss, acc, ppl, losses_auxiliary
def greedy(self, eouts, elens, max_len_ratio, exclude_eos=False, idx2token=None, refs_id=None, speakers=None, oracle=False): """Greedy decoding in the inference stage (used only for evaluation during training). Args: eouts (FloatTensor): `[B, T, enc_units]` elens (IntTensor): `[B]` max_len_ratio (int): maximum sequence length of tokens exclude_eos (bool): idx2token (): refs_id (list): speakers (list): oracle (bool): Returns: best_hyps (list): A list of length `[B]`, which contains arrays of size `[L]` aw (list): A list of length `[B]`, which contains arrays of size `[L, T]` """ bs, xmax = eouts.size()[:2] # Start from <sos> (<eos> in case of the backward decoder) ys_all = eouts.new_zeros(bs, 1).fill_(self.eos).long() # TODO(hirofumi): Create the source-target mask for batch decoding best_hyps_batch = [] ylens = torch.zeros(bs).int() yy_aws_tmp = [None] * bs xy_aws_tmp = [None] * bs eos_flags = [False] * bs for t in range(int(np.floor(xmax * max_len_ratio)) + 1): # Create the self-attention mask yy_mask = make_pad_mask(ylens + 1, self.device_id).unsqueeze(1).expand( bs, t + 1, t + 1) yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, t + 1, t + 1) subsequent_mask = torch.tril(yy_mask.new_ones( (t + 1, t + 1)).byte(), diagonal=0) subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand( bs, self.attn_n_heads, t + 1, t + 1) yy_mask = yy_mask & subsequent_mask # Create the source-target mask xmax = eouts.size(1) x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).expand( bs, t + 1, xmax) y_mask = make_pad_mask(ylens + 1, self.device_id).unsqueeze(2).expand( bs, t + 1, xmax) xy_mask = (x_mask * y_mask).unsqueeze(1).expand( bs, self.attn_n_heads, t + 1, xmax) out = self.pos_enc(self.embed(ys_all)) for l in range(self.n_layers): out, yy_aws, xy_aws = self.layers[l](out, yy_mask, eouts, xy_mask) out = self.norm_out(out) # Pick up 1-best y = self.output(out).argmax(-1)[:, -1:] best_hyps_batch += [y] # Count lengths of hypotheses for b in range(bs): if not eos_flags[b]: if y[b].item() == self.eos: eos_flags[b] = True yy_aws_tmp[b] = yy_aws[b:b + 1] # TODO: fix this xy_aws_tmp[b] = xy_aws[b:b + 1] ylens[b] += 1 # NOTE: include <eos> # Break if <eos> is outputed in all mini-bs if sum(eos_flags) == bs: break ys_all = torch.cat([ys_all, y], dim=-1) # Concatenate in L dimension best_hyps_batch = torch.cat(best_hyps_batch, dim=1) # xy_aws_tmp = torch.stack(xy_aws_tmp, dim=0) # Convert to numpy best_hyps_batch = tensor2np(best_hyps_batch) # xy_aws_tmp = tensor2np(xy_aws_tmp) # if self.score.attn_n_heads > 1: # xy_aws_tmp = xy_aws_tmp[:, :, :, 0] # # TODO(hirofumi): fix for MHA # Truncate by the first <eos> (<sos> in case of the backward decoder) if self.bwd: # Reverse the order best_hyps = [ best_hyps_batch[b, :ylens[b]][::-1] for b in range(bs) ] # aws = [xy_aws_tmp[b, :ylens[b]][::-1] for b in range(bs)] else: best_hyps = [best_hyps_batch[b, :ylens[b]] for b in range(bs)] # aws = [xy_aws_tmp[b, :ylens[b]] for b in range(bs)] # Exclude <eos> (<sos> in case of the backward decoder) if exclude_eos: if self.bwd: best_hyps = [ best_hyps[b][1:] if eos_flags[b] else best_hyps[b] for b in range(bs) ] else: best_hyps = [ best_hyps[b][:-1] if eos_flags[b] else best_hyps[b] for b in range(bs) ] # return best_hyps, aws return best_hyps, None
def forward_att(self, eouts, elens, ys, ys_hist=[], return_logits=False, teacher_logits=None):
    """Compute XE loss for the Transformer model.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        ys (list): A list of length `[B]`, which contains a list of size `[L]`
        ys_hist (list):
        return_logits (bool): return logits for knowledge distillation
        teacher_logits (FloatTensor): `[B, L, vocab]`
    Returns:
        loss (FloatTensor): `[1]`
        acc (float): accuracy for token prediction
        ppl (float): perplexity

    """
    bs = eouts.size(0)

    # Append <sos> and <eos>
    ys_in, ys_out, ylens = append_sos_eos(eouts, ys, self.eos, self.pad, self.bwd)

    # Create the self-attention mask
    bs, ytime = ys_in.size()[:2]
    tgt_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).repeat([1, ytime, 1])
    subsequent_mask = tgt_mask.new_ones(ytime, ytime).byte()
    subsequent_mask = torch.tril(subsequent_mask, out=subsequent_mask).unsqueeze(0)
    tgt_mask = tgt_mask & subsequent_mask

    # Create the source-target mask
    src_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat([1, ytime, 1])

    out = self.pos_enc(self.embed(ys_in))
    for l in range(self.n_layers):
        out, yy_aws, xy_aws = self.layers[l](out, tgt_mask, eouts, src_mask)
        if not self.training:
            setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
            setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws))
    logits = self.output(self.norm_out(out))

    # for knowledge distillation
    if return_logits:
        return logits

    # Compute XE sequence loss (+ label smoothing)
    loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training)

    # Compute token-level accuracy in teacher-forcing
    acc = compute_accuracy(logits, ys_out, self.pad)

    return loss, acc, ppl
def forward(self, xs, xlens, task, use_cache=False, streaming=False): """Forward computation. Args: xs (FloatTensor): `[B, T, input_dim]` xlens (list): `[B]` task (str): not supported now use_cache (bool): streaming (bool): streaming encoding Returns: eouts (dict): xs (FloatTensor): `[B, T, d_model]` xlens (list): `[B]` """ eouts = { 'ys': { 'xs': None, 'xlens': None }, 'ys_sub1': { 'xs': None, 'xlens': None }, 'ys_sub2': { 'xs': None, 'xlens': None } } if self.conv is None: xs = self.embed(xs) else: # Path through CNN blocks before RNN layers xs, xlens = self.conv(xs, xlens) if not self.training: self.data_dict['elens'] = tensor2np(xlens) bs, xmax, idim = xs.size() xs = self.pos_enc(xs) if self.chunk_size_left > 0: # Time-restricted self-attention for streaming models cs_l = self.chunk_size_left cs_c = self.chunk_size_cur cs_r = self.chunk_size_right xs_chunks = [] xx_aws = [[] for l in range(self.n_layers)] xs_pad = torch.cat([ xs.new_zeros(bs, cs_l, idim), xs, xs.new_zeros(bs, cs_r, idim) ], dim=1) # TODO: remove right padding for t in range(cs_l, cs_l + xmax, self.chunk_size_cur): xs_chunk = xs_pad[:, t - cs_l:t + cs_c + cs_r] for l, layer in enumerate(self.layers): xs_chunk, xx_aws_chunk = layer(xs_chunk, None) # no mask xx_aws[l].append(xx_aws_chunk[:, :, cs_l:cs_l + cs_c, cs_l:cs_l + cs_c]) xs_chunks.append(xs_chunk[:, cs_l:cs_l + cs_c]) xs = torch.cat(xs_chunks, dim=1)[:, :xmax] if not self.training: for l in range(self.n_layers): self.aws_dict['xx_aws_layer%d' % l] = tensor2np( torch.cat(xx_aws[l], dim=3)[:, :, :xmax, :xmax]) else: # Create the self-attention mask xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(2).repeat( [1, 1, xmax]) for l, layer in enumerate(self.layers): xs, xx_aws = layer(xs, xx_mask) if not self.training: self.aws_dict['xx_aws_layer%d' % l] = tensor2np(xx_aws) # Pick up outputs in the sub task before the projection layer if l == self.n_layers_sub1 - 1: xs_sub1 = self.layer_sub1( xs, xx_mask )[0] if self.task_specific_layer else xs.clone() xs_sub1 = self.norm_out_sub1(xs_sub1) if self.bridge_sub1 is not None: xs_sub1 = self.bridge_sub1(xs_sub1) if task == 'ys_sub1': eouts[task]['xs'], eouts[task][ 'xlens'] = xs_sub1, xlens return eouts if l == self.n_layers_sub2 - 1: xs_sub2 = self.layer_sub2( xs, xx_mask )[0] if self.task_specific_layer else xs.clone() xs_sub2 = self.norm_out_sub2(xs_sub2) if self.bridge_sub2 is not None: xs_sub2 = self.bridge_sub2(xs_sub2) if task == 'ys_sub2': eouts[task]['xs'], eouts[task][ 'xlens'] = xs_sub2, xlens return eouts xs = self.norm_out(xs) # Bridge layer if self.bridge is not None: xs = self.bridge(xs) if task in ['all', 'ys']: eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens if self.n_layers_sub1 >= 1 and task == 'all': eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens if self.n_layers_sub2 >= 1 and task == 'all': eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens return eouts
def forward(self, xs, xlens, task, streaming=False, lookback=False, lookahead=False): """Forward pass. Args: xs (FloatTensor): `[B, T, input_dim]` xlens (InteTensor): `[B]` (on CPU) task (str): ys/ys_sub1/ys_sub2 streaming (bool): streaming encoding lookback (bool): truncate leftmost frames for lookback in CNN context lookahead (bool): truncate rightmost frames for lookahead in CNN context Returns: eouts (dict): xs (FloatTensor): `[B, T, d_model]` xlens (InteTensor): `[B]` (on CPU) """ eouts = { 'ys': { 'xs': None, 'xlens': None }, 'ys_sub1': { 'xs': None, 'xlens': None }, 'ys_sub2': { 'xs': None, 'xlens': None } } N_l = self.chunk_size_left N_c = self.chunk_size_current N_r = self.chunk_size_right bs, xmax, idim = xs.size() n_chunks = 0 clamp_len = self.clamp_len if self.latency_controlled: if self.streaming_type == 'reshape': xs = chunkwise(xs, N_l, N_c, N_r) # `[B * n_chunks, N_l+N_c+N_r, idim]` elif self.streaming_type == 'mask': # xs = chunkwise(xs, N_l, N_c, N_r) # `[B * n_chunks, N_l+N_c+N_r, idim]` xs = chunkwise(xs, 0, N_c, 0) # `[B * n_chunks, N_c, idim]` n_chunks = xs.size(0) // bs if self.conv is None: xs = self.embed(xs) else: # Path through CNN blocks xs, xlens = self.conv(xs, xlens) N_l = max(0, N_l // self.conv.subsampling_factor) N_c = N_c // self.conv.subsampling_factor N_r = N_r // self.conv.subsampling_factor clamp_len = clamp_len // self.conv.subsampling_factor if self.streaming_type == 'mask': # Extract the center region emax = xlens.max().item() xs = xs.contiguous().view( bs, -1, xs.size(2))[:, :emax] # `[B, emax, d_model]` if self.latency_controlled: # streaming Transformer encoder emax = xlens.max().item() pos_embs = None if self.pe_type in ['relative', 'relative_xl']: xs = xs * self.scale pos_embs = self.pos_emb(xs, zero_center_offset=True ) # NOTE: no clamp_len for streaming else: xs = self.pos_enc(xs, scale=True) if self.streaming_type == 'reshape': xx_mask_first = None xx_mask = None # NOTE: no mask to avoid masking all frames in a chunk elif self.streaming_type == 'mask': xx_mask_first, xx_mask = time_restricted_mask( xs, xlens, N_l, N_c, N_r, n_chunks) for lth, layer in enumerate(self.layers): xs = layer(xs, xx_mask if lth >= 1 else xx_mask_first, pos_embs=pos_embs, u_bias=self.u_bias, v_bias=self.v_bias) if not self.training: if self.streaming_type == 'reshape': n_heads = layer.xx_aws.size(1) xx_aws = layer.xx_aws[:, :, N_l:N_l + N_c, N_l:N_l + N_c] xx_aws = xx_aws.view(bs, n_chunks, n_heads, N_c, N_c) xx_aws_center = xx_aws.new_zeros( bs, n_heads, emax, emax) for chunk_idx in range(n_chunks): offset = chunk_idx * N_c emax_chunk = xx_aws_center[:, :, offset:offset + N_c].size(2) xx_aws_chunk = xx_aws[:, chunk_idx, :, : emax_chunk, :emax_chunk] xx_aws_center[:, :, offset:offset + N_c, offset:offset + N_c] = xx_aws_chunk self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(xx_aws_center) elif self.streaming_type == 'mask': self.aws_dict['xx_aws_layer%d' % lth] = tensor2np( layer.xx_aws) self.data_dict['elens%d' % lth] = tensor2np(xlens) if self.subsample is not None: xs, xlens = self.subsample[lth](xs, xlens) emax = xlens.max().item() N_l = max(0, N_l // self.subsample[lth].subsampling_factor) N_c = N_c // self.subsample[lth].subsampling_factor N_r = N_r // self.subsample[lth].subsampling_factor if self.pe_type in ['relative', 'relative_xl']: # Create sinusoidal positional embeddings for relative positional encoding pos_embs = self.pos_emb( xs, zero_center_offset=True ) # NOTE: no clamp_len for streaming if self.streaming_type == 'mask': _, xx_mask = 
time_restricted_mask( xs, xlens, N_l, N_c, N_r, n_chunks) # Extract the center region if self.streaming_type == 'reshape': xs = xs[:, N_l:N_l + N_c] # `[B * n_chunks, N_c, d_model]` xs = xs.contiguous().view(bs, -1, xs.size(2)) xs = xs[:, :emax] else: if self.pe_type in ['relative', 'relative_xl']: xs = xs * self.scale # Create sinusoidal positional embeddings for relative positional encoding pos_embs = self.pos_emb(xs, clamp_len=clamp_len, zero_center_offset=True) else: xs = self.pos_enc(xs, scale=True) pos_embs = None # Create the self-attention mask xx_mask = make_pad_mask(xlens.to(self.device)).unsqueeze(1).repeat( [1, xs.size(1), 1]) for lth, layer in enumerate(self.layers): xs = layer(xs, xx_mask, pos_embs=pos_embs, u_bias=self.u_bias, v_bias=self.v_bias) if not self.training: self.aws_dict['xx_aws_layer%d' % lth] = tensor2np( layer.xx_aws) self.data_dict['elens%d' % lth] = tensor2np(xlens) # Pick up outputs in the sub task before the projection layer if lth == self.n_layers_sub1 - 1: xs_sub1 = self.sub_module(xs, xx_mask, lth, pos_embs, 'sub1') if task == 'ys_sub1': eouts[task]['xs'], eouts[task][ 'xlens'] = xs_sub1, xlens return eouts if lth == self.n_layers_sub2 - 1: xs_sub2 = self.sub_module(xs, xx_mask, lth, pos_embs, 'sub2') if task == 'ys_sub2': eouts[task]['xs'], eouts[task][ 'xlens'] = xs_sub2, xlens return eouts if self.subsample is not None: xs, xlens = self.subsample[lth](xs, xlens) # Create the self-attention mask xx_mask = make_pad_mask(xlens.to( self.device)).unsqueeze(1).repeat([1, xs.size(1), 1]) if self.pe_type in ['relative', 'relative_xl']: # Create sinusoidal positional embeddings for relative positional encoding clamp_len = clamp_len // self.subsample[ lth].subsampling_factor pos_embs = self.pos_emb(xs, clamp_len=clamp_len, zero_center_offset=True) xs = self.norm_out(xs) # Bridge layer if self.bridge is not None: xs = self.bridge(xs) if task in ['all', 'ys']: eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens if self.n_layers_sub1 >= 1 and task == 'all': eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens if self.n_layers_sub2 >= 1 and task == 'all': eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens return eouts
def forward(self, eouts, elens, ylens=None, mode='parallel'):
    """Forward pass.

    Args:
        eouts (FloatTensor): `[B, T, enc_dim]`
        elens (IntTensor): `[B]`
        ylens (IntTensor): `[B]`
        mode (str): parallel/incremental
    Returns:
        cv (FloatTensor): `[B, L, enc_dim]`
        alpha (FloatTensor): `[B, T]`
        aws (FloatTensor): `[B, L, T]`

    """
    bs, xmax, enc_dim = eouts.size()

    # 1d conv
    conv_feat = self.conv1d(eouts.transpose(2, 1)).transpose(2, 1)  # `[B, T, enc_dim]`
    conv_feat = torch.relu(self.norm(conv_feat))
    alpha = torch.sigmoid(self.proj(conv_feat)).squeeze(2)  # `[B, T]`

    # normalization
    if mode == 'parallel':
        # padding
        assert ylens is not None
        device = eouts.device
        ylens = ylens.to(device)
        mask = make_pad_mask(elens.to(device))
        alpha = alpha.clone().masked_fill_(mask == 0, 0)
        alpha_norm = alpha / alpha.sum(1, keepdim=True) * ylens.float().unsqueeze(1)
        ymax = int(ylens.max().item())
    elif mode == 'incremental':
        alpha_norm = alpha  # inference time
        ymax = 1
        if bs > 1:
            raise NotImplementedError('Batch mode is not supported.')
            # TODO(hirofumi0810): support batch mode
    else:
        raise ValueError(mode)

    cv = eouts.new_zeros(bs, ymax + 1, enc_dim)
    aws = eouts.new_zeros(bs, ymax + 1, xmax)
    n_tokens = torch.zeros(bs, dtype=torch.int64)
    state = eouts.new_zeros(bs, self.enc_dim)
    alpha_accum = eouts.new_zeros(bs)

    for j in range(xmax):
        alpha_accum_prev = alpha_accum
        alpha_accum += alpha_norm[:, j]

        if mode == 'parallel' and (alpha_accum >= self.beta).sum() == 0:
            # No boundary is located in any utterance in the mini-batch
            # Carry over to the next frame
            state += alpha_norm[:, j, None] * eouts[:, j]
            aws[:, n_tokens, j] += alpha_norm[:, j]
        else:
            for b in range(bs):
                # skip the padding region
                if j > elens[b] - 1:
                    continue
                # skip utterances whose tokens have all fired
                if mode == 'parallel' and n_tokens[b].item() >= ylens[b]:
                    continue

                if alpha_accum[b] < self.beta:
                    # No boundary is located
                    # Carry over to the next frame
                    state[b] += alpha_norm[b, j, None] * eouts[b, j]
                    aws[b, n_tokens[b], j] += alpha_norm[b, j]

                    # tail handling
                    if mode == 'incremental' and j == elens[b] - 1:
                        if alpha_accum[b] >= 0.5:
                            n_tokens[b] += 1
                            cv[b, n_tokens[b]] = state[b]
                        break
                else:
                    # A boundary is located
                    ak1 = 1 - alpha_accum_prev[b]
                    ak2 = alpha_norm[b, j] - ak1
                    cv[b, n_tokens[b]] = state[b] + ak1 * eouts[b, j]
                    aws[b, n_tokens[b], j] += ak1
                    n_tokens[b] += 1
                    # Carry over to the next frame
                    state[b] = ak2 * eouts[b, j]
                    alpha_accum[b] = ak2
                    aws[b, n_tokens[b], j] += ak2

                    if mode == 'incremental':
                        break

        if mode == 'incremental' and n_tokens[0] >= 1:
            break
            # TODO(hirofumi0810): support batch mode

    # truncate
    cv = cv[:, :ymax]
    aws = aws[:, :ymax]

    return cv, alpha, aws
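# A toy, single-utterance sketch of the continuous integrate-and-fire (CIF) rule
# used above: accumulate the normalized weights alpha over time and emit a token
# vector whenever the running sum crosses the threshold beta (1.0 here), splitting
# the boundary frame's weight into ak1 (which finishes the current token) and
# ak2 (which seeds the next one). Values and variable names are illustrative
# assumptions, not the module's actual configuration.
import torch

beta = 1.0
alpha = torch.tensor([0.4, 0.5, 0.4, 0.6, 0.3])  # `[T]`, already normalized
eouts = torch.arange(5).float().unsqueeze(1)     # `[T, 1]` dummy encoder outputs

cvs, state, accum = [], torch.zeros(1), 0.0
for j in range(alpha.size(0)):
    a_j = alpha[j].item()
    if accum + a_j < beta:
        # no boundary: integrate the full weight and carry it to the next frame
        state = state + a_j * eouts[j]
        accum += a_j
    else:
        # boundary: ak1 completes the current token, ak2 starts the next one
        ak1 = beta - accum
        ak2 = a_j - ak1
        cvs.append(state + ak1 * eouts[j])
        state = ak2 * eouts[j]
        accum = ak2

print(len(cvs))  # 2 token vectors fired (at frames 2 and 4)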
def align(self, logits, elens, ys, ylens):
    bs, xmax, vocab = logits.size()

    # zero padding
    device_id = torch.cuda.device_of(logits).idx
    mask = make_pad_mask(elens, device_id)
    mask = mask.unsqueeze(2).repeat([1, 1, vocab])
    logits = logits.masked_fill_(mask == 0, self.log0)
    log_probs = torch.log_softmax(logits, dim=-1).transpose(0, 1)  # `[T, B, vocab]`

    path = _label_to_path(ys, self.blank)
    path_lens = 2 * ylens.long() + 1

    ymax = ys.size(1)
    max_path_len = path.size(1)
    assert ys.size() == (bs, ymax), ys.size()
    assert path.size() == (bs, ymax * 2 + 1)

    alpha = log_probs.new_zeros(bs, max_path_len).fill_(self.log0)
    alpha[:, 0] = LOG_1
    beta = alpha.clone()
    gamma = alpha.clone()

    batch_index = torch.arange(bs, dtype=torch.int64).unsqueeze(1)
    seq_index = torch.arange(xmax, dtype=torch.int64).unsqueeze(1).unsqueeze(2)
    log_probs_fwd_bwd = log_probs[seq_index, batch_index, path]

    # forward algorithm
    for t in range(xmax):
        alpha = self._computes_transition(alpha, path, path_lens,
                                          log_probs_fwd_bwd[t], log_probs[t])

    # backward algorithm
    r_path = _flip_path(path, path_lens)
    log_probs_inv = _flip_label_probability(log_probs, elens.long())  # `[T, B, vocab]`
    log_probs_fwd_bwd = _flip_path_probability(log_probs_fwd_bwd, elens.long(), path_lens)  # `[T, B, 2*L+1]`
    for t in range(xmax):
        beta = self._computes_transition(beta, r_path, path_lens,
                                         log_probs_fwd_bwd[t], log_probs_inv[t])

    # pick up the best CTC path
    best_lattices = log_probs.new_zeros((bs, xmax), dtype=torch.int64)

    # forward algorithm
    log_probs_fwd_bwd = _flip_path_probability(log_probs_fwd_bwd, elens.long(), path_lens)
    for t in range(xmax):
        gamma = self._computes_transition(gamma, path, path_lens,
                                          log_probs_fwd_bwd[t], log_probs[t],
                                          skip_accum=True)

        # select paths where gamma is valid
        log_probs_fwd_bwd[t] = log_probs_fwd_bwd[t].masked_fill_(gamma == self.log0, self.log0)

        # pick up the best lattice
        offsets = log_probs_fwd_bwd[t].argmax(1)
        for b in range(bs):
            if t <= elens[b] - 1:
                token_idx = path[b, offsets[b]]
                best_lattices[b, t] = token_idx

        # remove the rest of paths
        gamma = log_probs.new_zeros(bs, max_path_len).fill_(self.log0)
        for b in range(bs):
            gamma[b, offsets[b]] = LOG_1

    # pick up trigger points
    trigger_lattices = torch.zeros((bs, xmax), dtype=torch.int64)
    trigger_points = log_probs.new_zeros((bs, ymax + 1), dtype=torch.int32)  # +1 for <eos>
    for b in range(bs):
        n_triggers = 0
        trigger_points[b, ylens[b]] = elens[b] - 1  # for <eos>
        for t in range(elens[b]):
            token_idx = best_lattices[b, t]
            if token_idx == self.blank:
                continue
            if not (t == 0 or token_idx != best_lattices[b, t - 1]):
                continue  # NOTE: select the leftmost trigger point

            trigger_lattices[b, t] = token_idx
            trigger_points[b, n_triggers] = t
            n_triggers += 1

    assert ylens.sum() == (trigger_lattices != 0).sum()

    return trigger_points
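# A minimal sketch of the label-to-path expansion that the forced alignment above
# relies on (called _label_to_path there): a reference sequence `[B, L]` is
# interleaved with blanks into the standard CTC state sequence `[B, 2L + 1]`
# (blank, y_1, blank, y_2, ..., y_L, blank). This mirrors what the helper is
# assumed to do; it is not the toolkit's implementation.
import torch


def label_to_path_sketch(ys, blank):
    """ys: `[B, L]` int64 -> path: `[B, 2L + 1]` int64."""
    bs, ymax = ys.size()
    path = ys.new_full((bs, 2 * ymax + 1), blank)
    path[:, 1::2] = ys  # odd positions hold the labels, even positions stay blank
    return path


ys = torch.tensor([[3, 5, 5, 7]])
print(label_to_path_sketch(ys, blank=0).tolist())
# [[0, 3, 0, 5, 0, 5, 0, 7, 0]]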
def forward_att(self, eouts, elens, ys, trigger_points=None):
    """Compute XE loss for the Transformer decoder.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        ys (List): length `[B]`, each of which contains a list of size `[L]`
        trigger_points (IntTensor): `[B, L]`
    Returns:
        loss (FloatTensor): `[1]`
        acc (float): accuracy for token prediction
        ppl (float): perplexity
        losses_auxiliary (dict):

    """
    losses_auxiliary = {}

    # Append <sos> and <eos>
    ys_in, ys_out, ylens = append_sos_eos(ys, self.eos, self.eos, self.pad, self.device, self.bwd)
    if not self.training:
        self.data_dict['elens'] = tensor2np(elens)
        self.data_dict['ylens'] = tensor2np(ylens)
        self.data_dict['ys'] = tensor2np(ys_out)

    # Create target self-attention mask
    bs, ymax = ys_in.size()[:2]
    tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1])
    causal_mask = tgt_mask.new_ones(ymax, ymax, dtype=tgt_mask.dtype)
    causal_mask = torch.tril(causal_mask).unsqueeze(0)
    tgt_mask = tgt_mask & causal_mask  # `[B, L (query), L (key)]`

    # Create source-target mask
    src_mask = make_pad_mask(elens.to(self.device)).unsqueeze(1).repeat([1, ymax, 1])  # `[B, L, T]`

    # Create attention padding mask for quantity loss
    if self.attn_type == 'mocha':
        attn_mask = (ys_out != self.pad).unsqueeze(1).unsqueeze(3)  # `[B, 1, L, 1]`
    else:
        attn_mask = None

    # external LM integration
    lmout = None
    if self.lm is not None:
        self.lm.eval()
        with torch.no_grad():
            lmout, lmstate, _ = self.lm.predict(ys_in, None)
        lmout = self.lm_output_proj(lmout)

    out = self.pos_enc(self.embed_token_id(ys_in), scale=True)  # scaled + dropout

    xy_aws_layers = []
    xy_aws = None
    for lth, layer in enumerate(self.layers):
        out = layer(out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout)

        # Attention padding
        xy_aws = layer.xy_aws
        if xy_aws is not None and self.attn_type == 'mocha':
            xy_aws_masked = xy_aws.masked_fill_(attn_mask.expand_as(xy_aws) == 0, 0)
            # NOTE: attention padding is quite effective for quantity loss
            xy_aws_layers.append(xy_aws_masked.clone())

        if not self.training:
            self.aws_dict['yy_aws_layer%d' % lth] = tensor2np(layer.yy_aws)
            self.aws_dict['xy_aws_layer%d' % lth] = tensor2np(layer.xy_aws)
            self.aws_dict['xy_aws_beta_layer%d' % lth] = tensor2np(layer.xy_aws_beta)
            self.aws_dict['xy_aws_p_choose%d' % lth] = tensor2np(layer.xy_aws_p_choose)
            self.aws_dict['yy_aws_lm_layer%d' % lth] = tensor2np(layer.yy_aws_lm)

    logits = self.output(self.norm_out(out))

    # Compute XE loss (+ label smoothing)
    loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training)

    # Quantity loss
    losses_auxiliary['loss_quantity'] = 0.
    if self.attn_type == 'mocha':
        # Average over all heads across all layers
        n_tokens_ref = tgt_mask[:, -1, :].sum(1).float()  # `[B]`
        # NOTE: count <eos> tokens
        n_tokens_pred = sum([torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                             for aws in xy_aws_layers])  # `[B]`
        n_tokens_pred /= len(xy_aws_layers)
        losses_auxiliary['loss_quantity'] = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))

    # Compute token-level accuracy in teacher-forcing
    acc = compute_accuracy(logits, ys_out, self.pad)

    return loss, acc, ppl, losses_auxiliary
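# A standalone sketch of the quantity loss computed above for MoChA attention:
# for each layer, the source-target attention weights `[B, H, L, T]` are summed
# over source frames and query positions and averaged over heads, giving an
# expected number of attended tokens; the loss is the absolute difference from
# the reference token count (including <eos>). Shapes and values here are
# illustrative assumptions.
import torch

bs, n_heads, ymax, xmax = 2, 4, 6, 30
xy_aws_layers = [torch.rand(bs, n_heads, ymax, xmax).softmax(dim=-1) for _ in range(3)]
n_tokens_ref = torch.tensor([6., 4.])  # reference lengths incl. <eos>

n_tokens_pred = sum(torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                    for aws in xy_aws_layers)       # `[B]`, summed over layers
n_tokens_pred = n_tokens_pred / len(xy_aws_layers)  # average over layers
loss_quantity = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))
print(loss_quantity.item())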
def forward(self, xs, xlens, task, use_cache=False, streaming=False):
    """Forward computation.

    Args:
        xs (FloatTensor): `[B, T, input_dim]`
        xlens (list): `[B]`
        task (str): not supported now
        use_cache (bool):
        streaming (bool): streaming encoding
    Returns:
        eouts (dict):
            xs (FloatTensor): `[B, T, d_model]`
            xlens (list): `[B]`

    """
    eouts = {'ys': {'xs': None, 'xlens': None},
             'ys_sub1': {'xs': None, 'xlens': None},
             'ys_sub2': {'xs': None, 'xlens': None}}

    if self.conv is None:
        xs = self.embed(xs)
    else:
        # Path through CNN blocks before RNN layers
        xs, xlens = self.conv(xs, xlens)

    bs, xmax, idim = xs.size()
    xs = self.pos_enc(xs)

    if self.chunk_size_left > 0:
        # Time-restricted self-attention for streaming models
        cs_l = self.chunk_size_left
        cs_c = self.chunk_size_current
        cs_r = self.chunk_size_right
        hop_size = self.chunk_size_current
        xs_chunks = []
        xx_aws = [[] for l in range(self.n_layers)]
        xs_pad = torch.cat([xs.new_zeros(bs, cs_l, idim),
                            xs,
                            xs.new_zeros(bs, cs_r, idim)], dim=1)
        # TODO: remove right padding
        for t in range(cs_l, cs_l + xmax, hop_size):
            xs_chunk = xs_pad[:, t - cs_l:t + cs_c + cs_r]
            for l in range(self.n_layers):
                xs_chunk, xx_aws_chunk = self.layers[l](xs_chunk, None)  # no mask
                xx_aws[l].append(xx_aws_chunk[:, :, cs_l:cs_l + cs_c, cs_l:cs_l + cs_c])
            xs_chunks.append(xs_chunk[:, cs_l:cs_l + cs_c])
        xs = torch.cat(xs_chunks, dim=1)[:, :xmax]
        if not self.training:
            for l in range(self.n_layers):
                setattr(self, 'xx_aws_layer%d' % l,
                        tensor2np(torch.cat(xx_aws[l], dim=3)[:, :, :xmax, :xmax]))
    else:
        # Create the self-attention mask
        xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(2).repeat([1, 1, xmax])

        for l in range(self.n_layers):
            xs, xx_aws = self.layers[l](xs, xx_mask)
            if not self.training:
                setattr(self, 'xx_aws_layer%d' % l, tensor2np(xx_aws))

    xs = self.norm_out(xs)

    # Bridge layer
    if self.bridge is not None:
        xs = self.bridge(xs)

    eouts['ys']['xs'] = xs
    eouts['ys']['xlens'] = xlens
    return eouts
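# A small sketch of how make_pad_mask is assumed to behave and how it is
# broadcast into a `[B, T, T]` self-attention mask in the encoders above
# (the exact broadcast axis differs between the variants): 1 marks valid
# frames, 0 marks padding, so attention to padded positions can be blocked.
# This is an illustration only, not the toolkit's make_pad_mask.
import torch


def make_pad_mask_sketch(xlens):
    """xlens: `[B]` -> mask: `[B, T]` with 1 for valid frames, 0 for padding."""
    xmax = int(xlens.max())
    return (torch.arange(xmax).unsqueeze(0) < xlens.unsqueeze(1)).long()


xlens = torch.tensor([4, 2])
xx_mask = make_pad_mask_sketch(xlens).unsqueeze(1).repeat([1, int(xlens.max()), 1])  # `[B, T, T]`
print(xx_mask[1])
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 0, 0]])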