def forward(self, xs, xlens, task, streaming=False, lookback=False, lookahead=False):
    """Forward pass.

    Embeds ``xs`` (``self.embed`` or the CNN front-end ``self.conv``), adds
    positional information and runs the result through ``self.layers``.
    When ``self.lc_bidir`` is set, encoding is chunkwise using
    ``self.streaming_type`` ('reshape' or 'mask'). Otherwise a self-attention
    mask from ``make_san_mask`` (unidirectional and/or with per-layer
    lookahead) is used. In streaming mode the per-layer caches in
    ``self.cache`` from the previous call are reused and replaced at the end.

    Args:
        xs (FloatTensor): `[B, T, input_dim]`
        xlens (IntTensor): `[B]` (on CPU)
        task (str): all/ys/ys_sub1/ys_sub2
        streaming (bool): streaming encoding
        lookback (bool): truncate leftmost frames for lookback in CNN context
            (forced to False for the CNN when ``self.lc_bidir`` is set)
        lookahead (bool): truncate rightmost frames for lookahead in CNN context
            (forced to False for the CNN when ``self.lc_bidir`` is set)
    Returns:
        eouts (dict): keyed by 'ys'/'ys_sub1'/'ys_sub2', each holding
            xs (FloatTensor): `[B, T, d_model]`
            xlens (IntTensor): `[B]` (on CPU)
            Entries that were not computed stay ``None``.

    """
    eouts = {
        'ys': {
            'xs': None,
            'xlens': None
        },
        'ys_sub1': {
            'xs': None,
            'xlens': None
        },
        'ys_sub2': {
            'xs': None,
            'xlens': None
        }
    }

    bs, xmax = xs.size()[:2]
    n_chunks = 0
    unidir = self.unidir
    lc_bidir = self.lc_bidir
    N_l, N_c, N_r = self.chunk_size_left, self.chunk_size_current, self.chunk_size_right

    # In streaming mode, one call must not exceed a single chunk (plus its
    # left/right context in 'reshape' mode).
    if streaming and self.streaming_type == 'mask':
        assert xmax <= N_c
    elif streaming and self.streaming_type == 'reshape':
        assert xmax <= (N_l + N_c + N_r)

    if lc_bidir:
        if self.streaming_type == 'mask' and not streaming:
            xs = chunkwise(xs, 0, N_c, 0, padding=True)  # `[B * n_chunks, N_c, idim]`
            # NOTE: CNN consumes inputs in the current chunk to avoid extra lookahead latency
            # That is, CNN outputs are independent on chunk boundary
        elif self.streaming_type == 'reshape':
            xs = chunkwise(xs, N_l, N_c, N_r, padding=not streaming)  # `[B * n_chunks, N_l+N_c+N_r, idim]`
        n_chunks = xs.size(0) // bs
        assert bs * n_chunks == xs.size(0)
        if streaming:
            assert n_chunks == 1, xs.size()

    if self.conv is None:
        xs = self.embed(xs)
    else:
        # Path through CNN blocks
        xs, xlens = self.conv(xs, xlens,
                              lookback=False if lc_bidir else lookback,
                              lookahead=False if lc_bidir else lookahead)
        # NOTE: CNN lookahead surpassing a chunk is not allowed in chunkwise processing
        # Chunk sizes are re-expressed in post-CNN frames.
        N_l = max(0, N_l // self.conv.subsampling_factor)
        N_c = N_c // self.conv.subsampling_factor
        N_r = N_r // self.conv.subsampling_factor

    if lc_bidir:
        # Do nothing in the streaming mode
        if self.streaming_type == 'mask' and not streaming:
            # back to the original shape (during training only)
            xs = xs.contiguous().view(bs, -1, xs.size(2))[:, :xlens.max()]  # `[B, emax, d_model]`
    elif streaming:
        xs = xs[:, :xlens.max()]  # for unidirectional

    # CNN-only encoder: no self-attention layers are applied.
    if self.enc_type == 'conv':
        eouts['ys']['xs'] = xs
        eouts['ys']['xlens'] = xlens
        return eouts

    # Offline encoding starts from an empty cache. In streaming mode, n_hist is
    # the number of cached frames from previous calls (0 on the first call).
    if not streaming:
        self.reset_cache()
    n_hist = self.cache[0]['input_san'].size(1) if streaming and self.cache[0] is not None else 0

    # positional encoding
    if self.pe_type in ['relative', 'relative_xl']:
        xs = xs * self.scale  # NOTE: first layer only
        rel_pos_embs = self.pos_emb(xs, mlen=n_hist)
    else:
        # Absolute positions are shifted by the cached history length.
        xs = self.pos_enc(xs, scale=True, offset=max(0, n_hist))
        rel_pos_embs = None

    new_cache = [None] * self.n_layers
    if lc_bidir:
        # chunkwise streaming encoder
        if self.streaming_type == 'reshape':
            xx_mask = None  # NOTE: no mask to avoid masking all frames in a chunk
        elif self.streaming_type == 'mask':
            if streaming:
                n_chunks = math.ceil((xlens.max().item() + n_hist) / N_c)
            xx_mask = make_chunkwise_san_mask(xs, xlens + n_hist, N_l, N_c, n_chunks)

        for lth, layer in enumerate(self.layers):
            xs, cache = layer(xs, xx_mask, cache=self.cache[lth],
                              pos_embs=rel_pos_embs,
                              u_bias=self.u_bias, v_bias=self.v_bias)
            # NOTE(review): only 'mask' mode keeps caches; in 'reshape' mode
            # new_cache stays all None, so self.cache is cleared at the end of a
            # streaming call -- confirm this is intended.
            if self.streaming_type == 'mask':
                new_cache[lth] = cache
            if not self.training and not streaming:
                if self.streaming_type == 'reshape':
                    # Stitch each chunk's center-to-center attention block onto the
                    # diagonal of a full `[B, H, emax, emax]` matrix for inspection.
                    n_heads = layer.xx_aws.size(1)
                    xx_aws = layer.xx_aws[:, :, N_l:N_l + N_c, N_l:N_l + N_c]
                    xx_aws = xx_aws.view(bs, n_chunks, n_heads, N_c, N_c)
                    emax = xlens.max().item()
                    xx_aws_center = xx_aws.new_zeros(bs, n_heads, emax, emax)
                    for chunk_idx in range(n_chunks):
                        offset = chunk_idx * N_c
                        # The last chunk may be cut off at emax.
                        emax_chunk = xx_aws_center[:, :, offset:offset + N_c].size(2)
                        xx_aws_chunk = xx_aws[:, chunk_idx, :, :emax_chunk, :emax_chunk]
                        xx_aws_center[:, :, offset:offset + N_c, offset:offset + N_c] = xx_aws_chunk
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(xx_aws_center)
                elif self.streaming_type == 'mask':
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(layer.xx_aws)
                self.data_dict['elens%d' % lth] = tensor2np(xlens)

            if self.subsample is not None:
                xs, xlens = self.subsample[lth](xs, xlens)
                N_l = max(0, N_l // self.subsample[lth].factor)
                N_c = N_c // self.subsample[lth].factor
                N_r = N_r // self.subsample[lth].factor
                # NOTE(review): unlike the first layer, n_hist is not applied
                # here (no mlen / no `xlens + n_hist`) -- confirm for streaming.
                if self.pe_type in ['relative', 'relative_xl']:
                    rel_pos_embs = self.pos_emb(xs)
                if self.streaming_type == 'mask':
                    xx_mask = make_chunkwise_san_mask(xs, xlens, N_l, N_c, n_chunks)

        # Extract the center region
        if self.streaming_type == 'reshape':
            xs = xs[:, N_l:N_l + N_c]  # `[B * n_chunks, N_c, d_model]`
            xs = xs.contiguous().view(bs, -1, xs.size(2))
            xs = xs[:, :xlens.max()]

        # NOTE(review): this branch never sets xs_sub1/xs_sub2, so task='all'
        # with n_layers_sub1/n_layers_sub2 >= 1 would raise NameError below.

    else:
        # Mask over current frames plus n_hist cached frames.
        xx_mask = make_san_mask(xs, xlens + n_hist, unidir, self.lookaheads[0])
        for lth, layer in enumerate(self.layers):
            xs, cache = layer(xs, xx_mask, cache=self.cache[lth],
                              pos_embs=rel_pos_embs,
                              u_bias=self.u_bias, v_bias=self.v_bias)
            new_cache[lth] = cache
            if not self.training and not streaming:
                self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(layer.xx_aws)
                self.data_dict['elens%d' % lth] = tensor2np(xlens)

            # Pick up outputs in the sub task before the projection layer
            if lth == self.n_layers_sub1 - 1:
                xs_sub1 = self.sub_module(xs, xx_mask, lth, rel_pos_embs, 'sub1')
                xlens_sub1 = xlens.clone()
                if task == 'ys_sub1':
                    eouts[task]['xs'], eouts[task]['xlens'] = xs_sub1, xlens_sub1
                    return eouts
            if lth == self.n_layers_sub2 - 1:
                xs_sub2 = self.sub_module(xs, xx_mask, lth, rel_pos_embs, 'sub2')
                xlens_sub2 = xlens.clone()
                if task == 'ys_sub2':
                    eouts[task]['xs'], eouts[task]['xlens'] = xs_sub2, xlens_sub2
                    return eouts

            if lth < len(self.layers) - 1:
                if self.subsample is not None and self.subsample[lth].factor > 1:
                    # Length changed: rebuild history length, pos. embeddings and mask
                    # for the next layer.
                    xs, xlens = self.subsample[lth](xs, xlens)
                    n_hist = self.cache[lth + 1]['input_san'].size(1) if streaming and self.cache[lth + 1] is not None else 0
                    if self.pe_type in ['relative', 'relative_xl']:
                        rel_pos_embs = self.pos_emb(xs, mlen=n_hist)
                    xx_mask = make_san_mask(xs, xlens + n_hist, unidir, self.lookaheads[lth + 1])
                elif self.lookaheads[lth] != self.lookaheads[lth + 1]:
                    # Same length, but the next layer uses a different lookahead.
                    xx_mask = make_san_mask(xs, xlens + n_hist, unidir, self.lookaheads[lth + 1])

    xs = self.norm_out(xs)
    if streaming:
        self.cache = new_cache

    # Bridge layer
    if self.bridge is not None:
        xs = self.bridge(xs)

    if task in ['all', 'ys']:
        eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
    if self.n_layers_sub1 >= 1 and task == 'all':
        eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens_sub1
    if self.n_layers_sub2 >= 1 and task == 'all':
        eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens_sub2
    return eouts
def forward(self, xs, xlens, task, use_cache=False, streaming=False):
    """Forward computation.

    Embeds ``xs`` (``self.embed`` or CNN blocks ``self.conv``) and runs it
    through the Conformer layers with relative positional embeddings.
    With ``self.latency_controlled`` the input is processed as overlapping
    chunks and only each chunk's center region is kept. Otherwise a
    key-padding self-attention mask is used.

    Args:
        xs (FloatTensor): `[B, T, input_dim]`
        xlens: `[B]` sequence lengths
        task (str): all/ys/ys_sub1/ys_sub2
        use_cache (bool): accepted for interface compatibility; not used here
        streaming (bool): accepted for interface compatibility; not used here
    Returns:
        eouts (dict): keyed by 'ys'/'ys_sub1'/'ys_sub2', each holding
            xs (FloatTensor): `[B, T, d_model]`
            xlens: `[B]`
            Entries that were not computed stay ``None``.

    """
    eouts = {key: {'xs': None, 'xlens': None} for key in ('ys', 'ys_sub1', 'ys_sub2')}

    N_l = self.chunk_size_left
    N_c = self.chunk_size_current
    N_r = self.chunk_size_right
    bs, xmax = xs.size()[:2]

    if self.latency_controlled:
        xs = chunkwise(xs, N_l, N_c, N_r)  # `[B * n_chunks, N_l+N_c+N_r, idim]`

    if self.conv is None:
        xs = self.embed(xs)
    else:
        # Path through CNN blocks
        xs, xlens = self.conv(xs, xlens)
    if not self.training:
        self.data_dict['elens'] = tensor2np(xlens)

    if self.latency_controlled:
        # streaming Conformer encoder; chunk sizes in post-CNN frames
        _N_l = max(0, N_l // self.subsampling_factor)
        _N_c = N_c // self.subsampling_factor

        n_chunks = math.ceil(xs.size(0) / bs)
        emax = math.ceil(xmax / self.subsampling_factor)

        xs = xs * self.scale
        pos_idxs = torch.arange(xs.size(1) - 1, -1, -1.0, dtype=torch.float)
        pos_embs = self.pos_emb(pos_idxs, self.device_id)

        xx_mask = None  # NOTE: no mask
        for lth, layer in enumerate(self.layers):
            xs = layer(xs, xx_mask, pos_embs=pos_embs)
            if not self.training:
                # Stitch each chunk's center-to-center attention block onto the
                # diagonal of a full `[B, H, emax, emax]` matrix for inspection.
                n_heads = layer.xx_aws.size(1)
                xx_aws = layer.xx_aws[:, :, _N_l:_N_l + _N_c, _N_l:_N_l + _N_c]
                xx_aws = xx_aws.view(bs, n_chunks, n_heads, _N_c, _N_c)
                xx_aws_center = xx_aws.new_zeros(bs, n_heads, emax, emax)
                for chunk_idx in range(n_chunks):
                    offset = chunk_idx * _N_c
                    # The last chunk may be cut off at emax.
                    emax_blc = xx_aws_center[:, :, offset:offset + _N_c].size(2)
                    xx_aws_chunk = xx_aws[:, chunk_idx, :, :emax_blc, :emax_blc]
                    xx_aws_center[:, :, offset:offset + _N_c, offset:offset + _N_c] = xx_aws_chunk
                self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(xx_aws_center)

        # Extract the center region
        xs = xs[:, _N_l:_N_l + _N_c]  # `[B * n_chunks, _N_c, d_model]`
        xs = xs.contiguous().view(bs, -1, xs.size(2))
        xs = xs[:, :emax]
        # NOTE(review): sub-task outputs are not computed on this path, so
        # task='all' with n_layers_sub1/n_layers_sub2 >= 1 raises NameError below.

    else:
        xmax = xs.size(1)  # frame count after the front-end
        xs = xs * self.scale

        # Key-padding self-attention mask `[B, xmax (query), xmax (key)]`:
        # entry (b, i, j) keeps key frame j only if it is a valid (non-padded)
        # frame of utterance b. The previous `unsqueeze(2).repeat([1, 1, xmax])`
        # masked along the query axis instead, letting valid frames attend to
        # padding.
        xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(1).repeat([1, xmax, 1])

        pos_idxs = torch.arange(xmax - 1, -1, -1.0, dtype=torch.float)
        pos_embs = self.pos_emb(pos_idxs, self.device_id)

        for lth, layer in enumerate(self.layers):
            xs = layer(xs, xx_mask, pos_embs=pos_embs)
            if not self.training:
                self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(layer.xx_aws)

            # Pick up outputs in the sub task before the projection layer
            if lth == self.n_layers_sub1 - 1:
                if self.task_specific_layer:
                    xs_sub1 = self.layer_sub1(xs, xx_mask, pos_embs=pos_embs)
                else:
                    xs_sub1 = xs.clone()
                xs_sub1 = self.norm_out_sub1(xs_sub1)
                if self.bridge_sub1 is not None:
                    xs_sub1 = self.bridge_sub1(xs_sub1)
                if task == 'ys_sub1':
                    eouts[task]['xs'], eouts[task]['xlens'] = xs_sub1, xlens
                    return eouts
            if lth == self.n_layers_sub2 - 1:
                if self.task_specific_layer:
                    xs_sub2 = self.layer_sub2(xs, xx_mask, pos_embs=pos_embs)
                else:
                    xs_sub2 = xs.clone()
                xs_sub2 = self.norm_out_sub2(xs_sub2)
                if self.bridge_sub2 is not None:
                    xs_sub2 = self.bridge_sub2(xs_sub2)
                if task == 'ys_sub2':
                    eouts[task]['xs'], eouts[task]['xlens'] = xs_sub2, xlens
                    return eouts

    xs = self.norm_out(xs)

    # Bridge layer
    if self.bridge is not None:
        xs = self.bridge(xs)

    if task in ['all', 'ys']:
        eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
    if self.n_layers_sub1 >= 1 and task == 'all':
        eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens
    if self.n_layers_sub2 >= 1 and task == 'all':
        eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens
    return eouts
def forward(self, xs, xlens, task, streaming=False, lookback=False, lookahead=False):
    """Forward pass.

    Embeds ``xs`` (``self.embed`` or CNN blocks ``self.conv``), adds positional
    information and runs it through ``self.layers``. With
    ``self.latency_controlled`` the input is encoded chunkwise according to
    ``self.streaming_type`` ('reshape' or 'mask'). Otherwise a full-sequence
    mask from ``make_san_mask`` is used (unidirectional if
    ``self.unidirectional``).

    Args:
        xs (FloatTensor): `[B, T, input_dim]`
        xlens (IntTensor): `[B]` (on CPU)
        task (str): all/ys/ys_sub1/ys_sub2
        streaming (bool): streaming encoding (not referenced in this body)
        lookback (bool): truncate leftmost frames for lookback in CNN context
            (not referenced in this body)
        lookahead (bool): truncate rightmost frames for lookahead in CNN context
            (not referenced in this body)
    Returns:
        eouts (dict): keyed by 'ys'/'ys_sub1'/'ys_sub2', each holding
            xs (FloatTensor): `[B, T, d_model]`
            xlens (IntTensor): `[B]` (on CPU)
            Entries that were not computed stay ``None``.

    """
    eouts = {
        'ys': {
            'xs': None,
            'xlens': None
        },
        'ys_sub1': {
            'xs': None,
            'xlens': None
        },
        'ys_sub2': {
            'xs': None,
            'xlens': None
        }
    }

    N_l = self.chunk_size_left
    N_c = self.chunk_size_current
    N_r = self.chunk_size_right
    bs = xs.size(0)
    n_chunks = 0
    clamp_len = self.clamp_len

    if self.latency_controlled:
        if self.streaming_type == 'reshape':
            xs = chunkwise(xs, N_l, N_c, N_r)  # `[B * n_chunks, N_l+N_c+N_r, idim]`
        elif self.streaming_type == 'mask':
            # xs = chunkwise(xs, N_l, N_c, N_r)  # `[B * n_chunks, N_l+N_c+N_r, idim]`
            xs = chunkwise(xs, 0, N_c, 0)  # `[B * n_chunks, N_c, idim]`
        n_chunks = xs.size(0) // bs

    if self.conv is None:
        xs = self.embed(xs)
    else:
        # Path through CNN blocks
        xs, xlens = self.conv(xs, xlens)
        # Chunk sizes and clamp length are re-expressed in post-CNN frames.
        N_l = max(0, N_l // self.conv.subsampling_factor)
        N_c = N_c // self.conv.subsampling_factor
        N_r = N_r // self.conv.subsampling_factor
        clamp_len = clamp_len // self.conv.subsampling_factor

    # NOTE(review): this also runs when latency_controlled is False; then xs
    # already has batch size bs, so the view keeps its shape and the slice only
    # truncates to max(xlens).
    if self.streaming_type == 'mask':
        # Extract the center region (merge chunks back into one sequence per utterance)
        emax = xlens.max().item()
        xs = xs.contiguous().view(bs, -1, xs.size(2))[:, :emax]  # `[B, emax, d_model]`

    if self.latency_controlled:
        # streaming encoder
        emax = xlens.max().item()

        pos_embs = None
        if self.pe_type in ['relative', 'relative_xl']:
            xs = xs * self.scale
            pos_embs = self.pos_emb(xs, zero_center_offset=True)  # NOTE: no clamp_len for streaming
        else:
            xs = self.pos_enc(xs, scale=True)

        # NOTE(review): xx_mask/xx_mask_first stay unbound if streaming_type is
        # neither 'reshape' nor 'mask'.
        if self.streaming_type == 'reshape':
            xx_mask_first = None
            xx_mask = None  # NOTE: no mask to avoid masking all frames in a chunk
        elif self.streaming_type == 'mask':
            xx_mask_first, xx_mask = make_time_restricted_san_mask(xs, xlens, N_l, N_c, N_r, n_chunks)

        for lth, layer in enumerate(self.layers):
            # The first layer uses xx_mask_first; later layers use xx_mask.
            xs = layer(xs, xx_mask if lth >= 1 else xx_mask_first, pos_embs=pos_embs,
                       u_bias=self.u_bias, v_bias=self.v_bias)
            if not self.training:
                if self.streaming_type == 'reshape':
                    # Stitch each chunk's center-to-center attention block onto the
                    # diagonal of a full `[B, H, emax, emax]` matrix for inspection.
                    n_heads = layer.xx_aws.size(1)
                    xx_aws = layer.xx_aws[:, :, N_l:N_l + N_c, N_l:N_l + N_c]
                    xx_aws = xx_aws.view(bs, n_chunks, n_heads, N_c, N_c)
                    xx_aws_center = xx_aws.new_zeros(bs, n_heads, emax, emax)
                    for chunk_idx in range(n_chunks):
                        offset = chunk_idx * N_c
                        # The last chunk may be cut off at emax.
                        emax_chunk = xx_aws_center[:, :, offset:offset + N_c].size(2)
                        xx_aws_chunk = xx_aws[:, chunk_idx, :, :emax_chunk, :emax_chunk]
                        xx_aws_center[:, :, offset:offset + N_c, offset:offset + N_c] = xx_aws_chunk
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(xx_aws_center)
                elif self.streaming_type == 'mask':
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(layer.xx_aws)
                self.data_dict['elens%d' % lth] = tensor2np(xlens)

            if self.subsample is not None:
                xs, xlens = self.subsample[lth](xs, xlens)
                emax = xlens.max().item()
                N_l = max(0, N_l // self.subsample[lth].subsampling_factor)
                N_c = N_c // self.subsample[lth].subsampling_factor
                N_r = N_r // self.subsample[lth].subsampling_factor
                if self.pe_type in ['relative', 'relative_xl']:
                    # Create sinusoidal positional embeddings for relative positional encoding
                    pos_embs = self.pos_emb(xs, zero_center_offset=True)  # NOTE: no clamp_len for streaming
                if self.streaming_type == 'mask':
                    _, xx_mask = make_time_restricted_san_mask(xs, xlens, N_l, N_c, N_r, n_chunks)

        # Extract the center region
        if self.streaming_type == 'reshape':
            xs = xs[:, N_l:N_l + N_c]  # `[B * n_chunks, N_c, d_model]`
            xs = xs.contiguous().view(bs, -1, xs.size(2))
            xs = xs[:, :emax]

        # NOTE(review): this branch never sets xs_sub1/xs_sub2, so task='all'
        # with n_layers_sub1/n_layers_sub2 >= 1 would raise NameError below.

    else:
        if self.pe_type in ['relative', 'relative_xl']:
            xs = xs * self.scale
            # Create sinusoidal positional embeddings for relative positional encoding
            pos_embs = self.pos_emb(xs, clamp_len=clamp_len, zero_center_offset=True)
        else:
            xs = self.pos_enc(xs, scale=True)
            pos_embs = None

        xx_mask = make_san_mask(xs, xlens, self.unidirectional)
        for lth, layer in enumerate(self.layers):
            xs = layer(xs, xx_mask, pos_embs=pos_embs, u_bias=self.u_bias, v_bias=self.v_bias)
            if not self.training:
                self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(layer.xx_aws)
                self.data_dict['elens%d' % lth] = tensor2np(xlens)

            # Pick up outputs in the sub task before the projection layer
            if lth == self.n_layers_sub1 - 1:
                xs_sub1 = self.sub_module(xs, xx_mask, lth, pos_embs, 'sub1')
                if task == 'ys_sub1':
                    eouts[task]['xs'], eouts[task]['xlens'] = xs_sub1, xlens
                    return eouts
            if lth == self.n_layers_sub2 - 1:
                xs_sub2 = self.sub_module(xs, xx_mask, lth, pos_embs, 'sub2')
                if task == 'ys_sub2':
                    eouts[task]['xs'], eouts[task]['xlens'] = xs_sub2, xlens
                    return eouts

            if self.subsample is not None:
                # Length changed: rebuild mask (and relative embeddings) for the next layer.
                xs, xlens = self.subsample[lth](xs, xlens)
                xx_mask = make_san_mask(xs, xlens, self.unidirectional)
                if self.pe_type in ['relative', 'relative_xl']:
                    # Create sinusoidal positional embeddings for relative positional encoding
                    clamp_len = clamp_len // self.subsample[lth].subsampling_factor
                    pos_embs = self.pos_emb(xs, clamp_len=clamp_len, zero_center_offset=True)

    xs = self.norm_out(xs)

    # Bridge layer
    if self.bridge is not None:
        xs = self.bridge(xs)

    if task in ['all', 'ys']:
        eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
    if self.n_layers_sub1 >= 1 and task == 'all':
        eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens
    if self.n_layers_sub2 >= 1 and task == 'all':
        eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens
    return eouts
def forward(self, xs, xlens, task, use_cache=False, streaming=False):
    """Forward pass.

    Embeds ``xs`` (``self.embed`` or CNN blocks ``self.conv``), adds positional
    information and runs it through ``self.layers``. With
    ``self.latency_controlled`` the input is encoded chunkwise according to
    ``self.lc_type`` ('reshape' or 'mask'). Otherwise a key-padding
    self-attention mask is used.

    Args:
        xs (FloatTensor): `[B, T, input_dim]`
        xlens (IntTensor): `[B]` (on CPU)
        task (str): all/ys/ys_sub1/ys_sub2
        use_cache (bool): accepted for interface compatibility; not used here
        streaming (bool): accepted for interface compatibility; not used here
    Returns:
        eouts (dict): keyed by 'ys'/'ys_sub1'/'ys_sub2', each holding
            xs (FloatTensor): `[B, T, d_model]`
            xlens (IntTensor): `[B]` (on CPU)
            Entries that were not computed stay ``None``.
    Raises:
        ValueError: if ``self.latency_controlled`` is set and ``self.lc_type``
            is neither 'reshape' nor 'mask'.

    """
    eouts = {key: {'xs': None, 'xlens': None} for key in ('ys', 'ys_sub1', 'ys_sub2')}

    N_l = self.chunk_size_left
    N_c = self.chunk_size_current
    N_r = self.chunk_size_right
    bs = xs.size(0)
    n_chunks = 0

    if self.latency_controlled:
        if self.lc_type == 'reshape':
            xs = chunkwise(xs, N_l, N_c, N_r)  # `[B * n_chunks, N_l+N_c+N_r, idim]`
        elif self.lc_type == 'mask':
            xs = chunkwise(xs, 0, N_c, 0)  # `[B * n_chunks, N_c, idim]`
        else:
            raise ValueError('Unsupported lc_type: %s' % self.lc_type)
        n_chunks = xs.size(0) // bs

    if self.conv is None:
        xs = self.embed(xs)
    else:
        # Path through CNN blocks; chunk sizes are re-expressed in post-CNN frames.
        xs, xlens = self.conv(xs, xlens)
        N_l = max(0, N_l // self.conv.subsampling_factor)
        N_c = N_c // self.conv.subsampling_factor
        N_r = N_r // self.conv.subsampling_factor

    if self.lc_type == 'mask':
        # Merge the chunks back into one sequence per utterance.
        emax = xlens.max().item()
        xs = xs.contiguous().view(bs, -1, xs.size(2))
        xs = xs[:, :emax]  # `[B, emax, d_model]`

    if self.latency_controlled:
        # streaming Transformer encoder
        emax = xlens.max().item()

        pos_embs = None
        if self.pe_type == 'relative':
            xs = xs * self.scale
            pos_idxs = torch.arange(xs.size(1) - 1, -1, -1.0, dtype=torch.float,
                                    device=self.device)
            pos_embs = self.pos_emb(pos_idxs)
        else:
            xs = self.pos_enc(xs, scale=True)

        xx_mask = None  # NOTE: no mask to avoid all masked region
        for lth, layer in enumerate(self.layers):
            xs = layer(xs, xx_mask, pos_embs=pos_embs)
            if not self.training:
                if self.lc_type == 'reshape':
                    # Stitch each chunk's center-to-center attention block onto the
                    # diagonal of a full `[B, H, emax, emax]` matrix for inspection.
                    n_heads = layer.xx_aws.size(1)
                    xx_aws = layer.xx_aws[:, :, N_l:N_l + N_c, N_l:N_l + N_c]
                    xx_aws = xx_aws.view(bs, n_chunks, n_heads, N_c, N_c)
                    xx_aws_center = xx_aws.new_zeros(bs, n_heads, emax, emax)
                    for chunk_idx in range(n_chunks):
                        offset = chunk_idx * N_c
                        # The last chunk may be cut off at emax.
                        emax_chunk = xx_aws_center[:, :, offset:offset + N_c].size(2)
                        xx_aws_chunk = xx_aws[:, chunk_idx, :, :emax_chunk, :emax_chunk]
                        xx_aws_center[:, :, offset:offset + N_c, offset:offset + N_c] = xx_aws_chunk
                    # Store the stitched matrix (it was previously computed and discarded).
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(xx_aws_center)
                elif self.lc_type == 'mask':
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(layer.xx_aws)
                self.data_dict['elens%d' % lth] = tensor2np(xlens)

            if self.subsample is not None:
                xs, xlens = self.subsample[lth](xs, xlens)
                emax = xlens.max().item()
                N_l = max(0, N_l // self.subsample[lth].subsampling_factor)
                N_c = N_c // self.subsample[lth].subsampling_factor
                N_r = N_r // self.subsample[lth].subsampling_factor
                if self.pe_type == 'relative':
                    # The sequence length changed, so rebuild the relative positional
                    # embeddings as the non-streaming branch does.
                    pos_idxs = torch.arange(xs.size(1) - 1, -1, -1.0, dtype=torch.float,
                                            device=self.device)
                    pos_embs = self.pos_emb(pos_idxs)
                if self.lc_type == 'mask':
                    # Time-restricted mask: queries in chunk k see keys from
                    # [offset - N_l, offset + N_c + N_r), restricted to valid frames.
                    xx_mask = make_pad_mask(xlens.to(self.device))
                    xx_mask = xx_mask.unsqueeze(1).repeat([1, xs.size(1), 1])  # `[B, emax (query), emax (key)]`
                    for chunk_idx in range(n_chunks):
                        offset = chunk_idx * N_c
                        xx_mask[:, offset:offset + N_c, :max(0, offset - N_l)] = 0
                        xx_mask[:, offset:offset + N_c, offset + (N_c + N_r):] = 0

        # Extract the center region
        if self.lc_type == 'reshape':
            xs = xs[:, N_l:N_l + N_c]  # `[B * n_chunks, N_c, d_model]`
            xs = xs.contiguous().view(bs, -1, xs.size(2))
            xs = xs[:, :emax]
        # NOTE(review): sub-task outputs are not computed on this path, so
        # task='all' with n_layers_sub1/n_layers_sub2 >= 1 raises NameError below.

    else:
        if self.pe_type == 'relative':
            xs = xs * self.scale
            # Create sinusoidal positional embeddings for relative positional encoding
            pos_idxs = torch.arange(xs.size(1) - 1, -1, -1.0, dtype=torch.float,
                                    device=self.device)
            pos_embs = self.pos_emb(pos_idxs)
        else:
            xs = self.pos_enc(xs, scale=True)
            pos_embs = None

        # Key-padding self-attention mask `[B, T (query), T (key)]`
        xx_mask = make_pad_mask(xlens.to(self.device)).unsqueeze(1).repeat([1, xs.size(1), 1])

        for lth, layer in enumerate(self.layers):
            xs = layer(xs, xx_mask, pos_embs=pos_embs)
            if not self.training:
                self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(layer.xx_aws)
                self.data_dict['elens%d' % lth] = tensor2np(xlens)

            # Pick up outputs in the sub task before the projection layer
            if lth == self.n_layers_sub1 - 1:
                xs_sub1 = self.sub_module(xs, xx_mask, lth, pos_embs, 'sub1')
                if task == 'ys_sub1':
                    eouts[task]['xs'], eouts[task]['xlens'] = xs_sub1, xlens
                    return eouts
            if lth == self.n_layers_sub2 - 1:
                xs_sub2 = self.sub_module(xs, xx_mask, lth, pos_embs, 'sub2')
                if task == 'ys_sub2':
                    eouts[task]['xs'], eouts[task]['xlens'] = xs_sub2, xlens
                    return eouts

            if self.subsample is not None:
                xs, xlens = self.subsample[lth](xs, xlens)
                # Create the self-attention mask
                xx_mask = make_pad_mask(xlens.to(self.device)).unsqueeze(1).repeat([1, xs.size(1), 1])
                if self.pe_type == 'relative':
                    # Create sinusoidal positional embeddings for relative positional encoding
                    pos_idxs = torch.arange(xs.size(1) - 1, -1, -1.0, dtype=torch.float,
                                            device=self.device)
                    pos_embs = self.pos_emb(pos_idxs)

    xs = self.norm_out(xs)

    # Bridge layer
    if self.bridge is not None:
        xs = self.bridge(xs)

    if task in ['all', 'ys']:
        eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
    if self.n_layers_sub1 >= 1 and task == 'all':
        eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens
    if self.n_layers_sub2 >= 1 and task == 'all':
        eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens
    return eouts
def forward(self, xs, xlens, task, streaming=False, lookback=False, lookahead=False):
    """Forward pass.

    Optionally runs CNN blocks (``self.conv``), then RNN layers. In the
    non-``lc_bidir`` case the batch is sorted by length (for
    ``pack_padded_sequence``) and restored to the caller's order before
    returning.

    Args:
        xs (FloatTensor): `[B, T, input_dim]`
        xlens (list): A list of length `[B]`
        task (str): all or ys or ys_sub1 or ys_sub2
        streaming (bool): streaming encoding
        lookback (bool): truncate leftmost frames for lookback in CNN context
        lookahead (bool): truncate rightmost frames for lookahead in CNN context
    Returns:
        eouts (dict):
            xs (FloatTensor): `[B, T // prod(subsample), n_units (*2)]`
            xlens (IntTensor): `[B]`
            xs_sub1 (FloatTensor): `[B, T // prod(subsample), n_units (*2)]`
            xlens_sub1 (IntTensor): `[B]`
            xs_sub2 (FloatTensor): `[B, T // prod(subsample), n_units (*2)]`
            xlens_sub2 (IntTensor): `[B]`
            Entries that were not computed stay ``None``.

    """
    eouts = {key: {'xs': None, 'xlens': None} for key in ('ys', 'ys_sub1', 'ys_sub2')}

    # Sort by lengths in the descending order for pack_padded_sequence
    perm_ids_unsort = None
    if not self.lc_bidir:
        xlens, perm_ids = torch.IntTensor(xlens).sort(0, descending=True)
        xs = xs[perm_ids]
        _, perm_ids_unsort = perm_ids.sort()  # inverse permutation

    # Dropout for inputs-hidden connection
    xs = self.dropout_in(xs)

    bs, xmax, idim = xs.size()
    N_c, N_r = self.N_c, self.N_r

    if self.lc_bidir and not self.cnn_lookahead:
        # NOTE(review): the chunks are merged back immediately, before self.conv
        # runs, so the CNN still sees the whole sequence and the net effect is
        # only truncation to xlens.max(). Also, xlens is not converted from a
        # list in lc_bidir mode, so `.max()` assumes the caller passed a
        # tensor -- confirm both.
        xs = chunkwise(xs, 0, N_c, 0)  # `[B * n_chunks, N_c, idim]`
        # Extract the center region
        xs = xs.contiguous().view(bs, -1, xs.size(2))
        xs = xs[:, :xlens.max()]  # `[B, emax, d_model]`

    # Path through CNN blocks before RNN layers
    if self.conv is not None:
        xs, xlens = self.conv(xs, xlens, lookback=lookback, lookahead=lookahead)
        if self.enc_type == 'conv':
            # Restore the caller's batch order (the final return path does the same).
            if perm_ids_unsort is not None:
                xs = xs[perm_ids_unsort]
                xlens = xlens[perm_ids_unsort]
            eouts['ys']['xs'] = xs
            eouts['ys']['xlens'] = xlens
            return eouts
        if self.lc_bidir:
            N_c = N_c // self.conv_factor
            N_r = N_r // self.conv_factor

    # During training, with probability rsp_prob, keep the previous hidden
    # states (only when the batch size matches the cached states).
    carry_over = self.rsp_prob > 0 and self.training and random.random() < self.rsp_prob
    carry_over = carry_over and (bs == (self.hx_fwd[0][0].size(0) if self.hx_fwd[0] is not None else 0))
    if not streaming and not carry_over:
        self.reset_cache()  # NOTE: do not reset here for streaming inference

    if self.lc_bidir:
        # Flip the layer and time loop
        if self.N_c <= 0:
            xs, xlens, xs_sub1, xlens_sub1 = self._forward_full_context(xs, xlens)
        else:
            xs, xlens, xs_sub1, xlens_sub1 = self._forward_latency_controlled(
                xs, xlens, N_c, N_r, streaming)
        if task == 'ys_sub1':
            eouts[task]['xs'], eouts[task]['xlens'] = xs_sub1, xlens_sub1
            return eouts

    else:
        for lth in range(self.n_layers):
            self.rnn[lth].flatten_parameters()  # for multi-GPUs
            xs, state = self.padding(xs, xlens, self.rnn[lth],
                                     prev_state=self.hx_fwd[lth],
                                     streaming=streaming)
            self.hx_fwd[lth] = state
            xs = self.dropout(xs)

            # Pick up outputs in the sub task before the projection layer
            if lth == self.n_layers_sub1 - 1:
                xs_sub1, xlens_sub1 = self.sub_module(xs, xlens, perm_ids_unsort, 'sub1')
                if task == 'ys_sub1':
                    eouts[task]['xs'], eouts[task]['xlens'] = xs_sub1, xlens_sub1
                    return eouts
            if lth == self.n_layers_sub2 - 1:
                xs_sub2, xlens_sub2 = self.sub_module(xs, xlens, perm_ids_unsort, 'sub2')
                if task == 'ys_sub2':
                    eouts[task]['xs'], eouts[task]['xlens'] = xs_sub2, xlens_sub2
                    return eouts

            # Projection layer
            if self.proj is not None and lth != self.n_layers - 1:
                xs = torch.relu(self.proj[lth](xs))
            # Subsampling layer
            if self.subsample is not None:
                xs, xlens = self.subsample[lth](xs, xlens)

    # Bridge layer
    if self.bridge is not None:
        xs = self.bridge(xs)

    xs = xs[:, :xlens.max()]

    if task in ['all', 'ys']:
        if perm_ids_unsort is not None:
            xs = xs[perm_ids_unsort]
            xlens = xlens[perm_ids_unsort]
        eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
    if self.n_layers_sub1 >= 1 and task == 'all':
        eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens_sub1
    if self.n_layers_sub2 >= 1 and task == 'all':
        eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens_sub2
    return eouts