def forward(
    self,
    src_tokens,
    src_lengths,
    z_src_tokens,
    z_src_lengths,
    prev_output_tokens,
    cls_input: Optional[Tensor] = None,
    return_all_hiddens: bool = True,
    features_only: bool = False,
    alignment_layer: Optional[int] = None,
    alignment_heads: Optional[int] = None,
):
    """
    Run the forward pass for an encoder-decoder model.

    Copied from the base class, but without ``**kwargs``, which are not
    supported by TorchScript.
    """
    encoder_out = self.encoder(
        src_tokens,
        src_lengths=src_lengths,
        cls_input=cls_input,
        return_all_hiddens=return_all_hiddens,
    )
    x = self.f1(encoder_out.encoder_out, encoder_out.encoder_padding_mask)
    encoder_out = EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_out.encoder_padding_mask,  # B x T
        encoder_embedding=encoder_out.encoder_embedding,  # B x T x C
        encoder_states=encoder_out.encoder_states,  # List[T x B x C]
    )
    z_encoder_out = self.encoder(
        z_src_tokens,
        src_lengths=z_src_lengths,
        cls_input=cls_input,
        return_all_hiddens=return_all_hiddens,
    )
    x = self.f2(z_encoder_out.encoder_out, z_encoder_out.encoder_padding_mask)
    z_encoder_out = EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=z_encoder_out.encoder_padding_mask,  # B x T
        encoder_embedding=z_encoder_out.encoder_embedding,  # B x T x C
        encoder_states=z_encoder_out.encoder_states,  # List[T x B x C]
    )
    decoder_out = self.decoder(
        prev_output_tokens,
        encoder_out=encoder_out,
        features_only=features_only,
        alignment_layer=alignment_layer,
        alignment_heads=alignment_heads,
        src_lengths=src_lengths,
        return_all_hiddens=return_all_hiddens,
        z_encoder_out=z_encoder_out,
        z_src_lengths=z_src_lengths,
    )
    return decoder_out
def forward(self, src_tokens, src_lengths):
    x, input_lengths = self.subsample(src_tokens, src_lengths)
    x = self.embed_scale * x

    encoder_padding_mask = lengths_to_padding_mask(input_lengths)
    positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
    x += positions
    x = self.dropout_module(x)

    for layer in self.transformer_layers:
        x = layer(x, encoder_padding_mask)

    if not encoder_padding_mask.any():
        encoder_padding_mask = None

    if self.layer_norm is not None:
        x = self.layer_norm(x)

    return EncoderOut(
        encoder_out=x,
        encoder_padding_mask=encoder_padding_mask,
        encoder_embedding=None,
        encoder_states=None,
        src_tokens=None,
        src_lengths=None,
    )
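# A minimal sketch of the mask convention assumed above: `lengths_to_padding_mask`
# (from fairseq.data.data_utils) is expected to return a B x T bool tensor that is
# True at padded positions. This illustrative re-implementation and the example
# values are ours, not the original helper.
import torch

def lengths_to_padding_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    bsz, max_len = lengths.size(0), int(lengths.max().item())
    # a position t is padding iff t >= length of that sequence
    return (torch.arange(max_len, device=lengths.device)
            .expand(bsz, max_len) >= lengths.unsqueeze(1))

# lengths_to_padding_mask_sketch(torch.tensor([3, 1])) ->
# tensor([[False, False, False],
#         [False,  True,  True]])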
def forward(self, src_tokens, src_lengths=None, **kwargs):
    return EncoderOut(
        encoder_out=src_tokens,
        encoder_padding_mask=None,
        encoder_embedding=None,
        encoder_states=None,
    )
def generate(self, models, sample, **unused):
    """Generate a batch of inferences.

    The encoder output returned by ``get_encoder_output`` has the form::

        EncoderOut(
            encoder_out=encoder_out['encoder_out'],  # T x B x C
            encoder_embedding=None,
            encoder_padding_mask=encoder_out['encoder_padding_mask'],  # B x T
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )
    """
    encoder_output = models[0].get_encoder_output(sample['net_input'])
    encoder_out = {
        "encoder_out": encoder_output.encoder_out.transpose(0, 1),  # B x T x C
        "padding_mask": encoder_output.encoder_padding_mask,
    }
    alphas, _ = models[0].assigner(encoder_out)
    # _alphas, num_output = self.resize(alphas, kwargs['target_lengths'], at_least_one=True)
    cif_outputs = models[0].cif(encoder_out, alphas)
    src_lengths = torch.round(alphas.sum(-1)).int()
    self.step_forward_fn = models[0].decode

    encoder_output = EncoderOut(
        encoder_out=cif_outputs.transpose(0, 1),  # T x B x C
        encoder_embedding=None,
        encoder_padding_mask=~utils.sequence_mask(src_lengths, dtype=torch.bool),  # B x T
        encoder_states=None,
        src_tokens=None,
        src_lengths=src_lengths,
    )
    return self.decode(encoder_output)
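# `utils.sequence_mask` above is assumed (given the `~` in front of it) to return a
# B x T mask that is True at *valid* positions, so negating it yields the usual
# padding mask. A hedged sketch of that behaviour; the helper name and signature
# here are illustrative, not the original implementation.
import torch

def sequence_mask_sketch(lengths: torch.Tensor, max_len: int = None,
                         dtype: torch.dtype = torch.bool) -> torch.Tensor:
    if max_len is None:
        max_len = int(lengths.max().item())
    valid = (torch.arange(max_len, device=lengths.device)
             .expand(lengths.size(0), max_len) < lengths.unsqueeze(1))
    return valid.to(dtype)

# ~sequence_mask_sketch(torch.tensor([2, 3])) gives the B x T padding mask:
# tensor([[False, False,  True],
#         [False, False, False]])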
def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`
        return_all_hiddens (bool, optional): also return all of the
            intermediate hidden states (default: False).

    Returns:
        namedtuple:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of
              padding elements of shape `(batch, src_len)`
            - **encoder_embedding** (Tensor): the (scaled) embedding lookup
              of shape `(batch, src_len, embed_dim)`
            - **encoder_states** (List[Tensor]): all intermediate hidden
              states of shape `(src_len, batch, embed_dim)`. Only populated
              if *return_all_hiddens* is True.
    """
    if self.layer_wise_attention:
        return_all_hiddens = True

    x, encoder_embedding = self.forward_embedding(src_tokens)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # compute padding mask
    encoder_padding_mask = src_tokens.eq(self.padding_idx)

    encoder_states = [] if return_all_hiddens else None

    position_bias = None
    if self.rel_pos:
        seq_len, bsz, _ = x.size()
        position_bias = self.compute_bias(seq_len, seq_len)  # (1, n_heads, qlen, klen)
        position_bias = position_bias.repeat(bsz, 1, 1, 1)
        position_bias = position_bias.view(bsz * self.num_heads, seq_len, seq_len)

    # encoder layers
    for layer in self.layers:
        x = layer(x, encoder_padding_mask, position_bias=position_bias)
        if return_all_hiddens:
            assert encoder_states is not None
            encoder_states.append(x)

    if self.layer_norm is not None:
        x = self.layer_norm(x)

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_embedding=encoder_embedding,  # B x T x C
        encoder_states=encoder_states,  # List[T x B x C]
        src_tokens=None,
        src_lengths=None,
    )
def reorder_encoder_out(self, encoder_out, new_order):
    return EncoderOut(
        encoder_out=encoder_out.encoder_out.index_select(0, new_order),
        encoder_padding_mask=None,
        encoder_embedding=None,
        encoder_states=None,
    )
def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False,
            return_all_attn: bool = False):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`
        return_all_hiddens (bool, optional): also return all of the
            intermediate hidden states (default: False).
        return_all_attn (bool, optional): also return all of the
            intermediate layers' attention weights (default: False).

    Returns:
        namedtuple:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of
              padding elements of shape `(batch, src_len)`
            - **encoder_embedding** (Tensor): the (scaled) embedding lookup
              of shape `(batch, src_len, embed_dim)`
            - **encoder_states** (List[Tensor]): all intermediate hidden
              states of shape `(src_len, batch, embed_dim)`. Only populated
              if *return_all_hiddens* is True.
            - **encoder_attn** (List[Tensor]): all intermediate layers'
              attention weights of shape `(num_heads, batch, src_len, src_len)`.
              Only populated if *return_all_attn* is True.
    """
    x, encoder_embedding = self.forward_embedding(src_tokens)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # compute padding mask
    encoder_padding_mask = src_tokens.eq(self.padding_idx)

    encoder_states = [] if return_all_hiddens else None
    encoder_attn = [] if return_all_attn else None

    # encoder layers
    for layer in self.layers:
        x, attn = layer(x, encoder_padding_mask, need_head_weights=return_all_attn)
        if return_all_hiddens:
            assert encoder_states is not None
            encoder_states.append(x)
        if return_all_attn and attn is not None:
            assert encoder_attn is not None
            encoder_attn.append(attn)

    if self.layer_norm is not None:
        x = self.layer_norm(x)

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_embedding=encoder_embedding,  # B x T x C
        encoder_states=encoder_states,  # List[T x B x C]
        encoder_attn=encoder_attn,  # List[N x B x T x T]
        src_tokens=None,
        src_lengths=None,
    )
def reorder_encoder_out(self, encoder_out, new_order): """ if self.beam_size < 0: self.beam_size = int(new_order.shape[0] / self.batch_size) else: new_order = new_order // self.beam_size new_order = new_order[:: self.beam_size] new_encoder_out = encoder_out.encoder_out.index_select(1, new_order) new_encoder_padding_mask = encoder_out.encoder_padding_mask.index_select( 0, new_order ) """ new_encoder_out = encoder_out.encoder_out.index_select(1, new_order) new_encoder_padding_mask = encoder_out.encoder_padding_mask.index_select( 0, new_order ) return EncoderOut( encoder_out=new_encoder_out, # T x B x C encoder_padding_mask=new_encoder_padding_mask, # B x T encoder_embedding=None, # B x T x C encoder_states=None, src_tokens=None, src_lengths=None, )
def forward(self, src_tokens, cluster_ids, src_lengths, return_all_hiddens: bool = False):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`
        return_all_hiddens (bool, optional): also return all of the
            intermediate hidden states (default: False).

    Returns:
        namedtuple:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of
              padding elements of shape `(batch, src_len)`
            - **encoder_embedding** (Tensor): the (scaled) embedding lookup
              of shape `(batch, src_len, embed_dim)`
            - **encoder_states** (List[Tensor]): all intermediate hidden
              states of shape `(src_len, batch, embed_dim)`. Only populated
              if *return_all_hiddens* is True.
    """
    x, encoder_embedding = self.forward_embedding(src_tokens)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # compute padding mask
    encoder_padding_mask = src_tokens.eq(self.padding_idx)

    encoder_states = [] if return_all_hiddens else None

    # encoder layers
    for layer in self.layers:
        if isinstance(layer, TransformerClusterEncoderLayer):
            x = layer(x, encoder_padding_mask, int(cluster_ids[0]))
        else:
            x = layer(x, encoder_padding_mask)
        if return_all_hiddens:
            assert encoder_states is not None
            encoder_states.append(x)

    if self.layer_norm is not None:
        x = self.layer_norm(x)

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_embedding=encoder_embedding,  # B x T x C
        encoder_states=encoder_states,  # List[T x B x C]
        src_tokens=None,
        src_lengths=None,
    )
def forward(self, src_videos, src_lengths=None, **kwargs):
    x = self.cnn(src_videos.transpose(1, 2).contiguous())  # B x C x T
    x = x.transpose(1, 2).contiguous().transpose(0, 1)  # T x B x C
    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=None,  # B x T
        encoder_embedding=None,  # B x T x C
        encoder_states=None,  # List[T x B x C]
    )
def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ """ Since encoder_padding_mask and encoder_embedding are both of type Optional[Tensor] in EncoderOut, they need to be copied as local variables for Torchscript Optional refinement """ encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding new_encoder_out = ( encoder_out.encoder_out if encoder_out.encoder_out is None else encoder_out.encoder_out.index_select(1, new_order) ) new_encoder_padding_mask = ( encoder_padding_mask if encoder_padding_mask is None else encoder_padding_mask.index_select(0, new_order) ) new_encoder_embedding = ( encoder_embedding if encoder_embedding is None else encoder_embedding.index_select(0, new_order) ) src_tokens = encoder_out.src_tokens if src_tokens is not None: src_tokens = src_tokens.index_select(0, new_order) src_lengths = encoder_out.src_lengths if src_lengths is not None: src_lengths = src_lengths.index_select(0, new_order) encoder_states = encoder_out.encoder_states if encoder_states is not None: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) return EncoderOut( encoder_out=new_encoder_out, # T x B x C encoder_padding_mask=new_encoder_padding_mask, # B x T encoder_embedding=new_encoder_embedding, # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=src_tokens, # B x T src_lengths=src_lengths, # B x 1 )
def forward(self, src_tokens, src_lengths: Tensor, **unused):
    if self.left_pad:
        # nn.utils.rnn.pack_padded_sequence requires right-padding;
        # convert left-padding to right-padding
        src_tokens = speech_utils.convert_padding_direction(
            src_tokens,
            src_lengths,
            left_to_right=True,
        )
    if self.conv_layers_before is not None:
        x, src_lengths, padding_mask = self.conv_layers_before(src_tokens, src_lengths)
    else:
        x, padding_mask = src_tokens, \
            ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1))

    bsz, seqlen = x.size(0), x.size(1)

    x = F.dropout(x, p=self.dropout_in, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    state_size = 2 if self.bidirectional else 1, bsz, self.hidden_size
    h0, c0 = x.new_zeros(*state_size), x.new_zeros(*state_size)

    for i in range(len(self.lstm)):
        if self.residual and i > 0:  # residual connection starts from the 2nd layer
            prev_x = x
        # pack embedded source tokens into a PackedSequence
        packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data)
        # apply LSTM
        packed_outs, (_, _) = self.lstm[i](packed_x, (h0, c0))
        # unpack outputs and apply dropout
        x, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outs, padding_value=self.padding_value * 1.0)
        if i < len(self.lstm) - 1:  # not applying dropout for the last layer
            x = F.dropout(x, p=self.dropout_out, training=self.training)
        x = x + prev_x if self.residual and i > 0 else x
    assert list(x.size()) == [seqlen, bsz, self.output_units]

    encoder_padding_mask = padding_mask.t()

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask
        if encoder_padding_mask.any() else None,  # T x B
        encoder_embedding=None,
        encoder_states=None,
        src_tokens=None,
        src_lengths=src_lengths,  # B
    )
def get_encoder_output(self, net_input):
    encoder_out = self.encoder(tbc=True, **net_input)
    return EncoderOut(
        encoder_out=encoder_out['encoder_out'],  # T x B x C
        encoder_embedding=None,
        encoder_padding_mask=encoder_out['encoder_padding_mask'],  # B x T
        encoder_states=None,
        src_tokens=None,
        src_lengths=None,
    )
def forward(self, src_tokens, src_lengths, **kwargs):
    d = super().forward(src_tokens, src_lengths, **kwargs)
    epm = d.get('encoder_padding_mask', None)
    epm = epm.t() if epm is not None else None
    return EncoderOut(
        encoder_out=d['encoder_out'],  # T x B x C
        encoder_padding_mask=epm,  # B x T
        encoder_embedding=None,  # B x T x C
        encoder_states=None,  # List[T x B x C]
    )
def forward(self, src_tokens, src_lengths=None, **kwargs):
    assert "fancy_other_input" in kwargs
    assert kwargs["fancy_other_input"] is not None
    return EncoderOut(
        encoder_out=src_tokens,
        encoder_padding_mask=None,
        encoder_embedding=None,
        encoder_states=None,
        src_tokens=None,
        src_lengths=None,
    )
def reorder_encoder_out(self, encoder_out: EncoderOut, new_order):
    encoder_padding_mask = encoder_out.encoder_padding_mask.index_select(1, new_order) \
        if encoder_out.encoder_padding_mask is not None else None
    return EncoderOut(
        encoder_out=encoder_out.encoder_out.index_select(1, new_order),
        encoder_padding_mask=encoder_padding_mask,
        encoder_embedding=None,
        encoder_states=None,
        src_tokens=None,
        src_lengths=encoder_out.src_lengths.index_select(0, new_order),
    )
def forward(
    self,
    src_tokens,
    src_lengths,
    return_all_hiddens: bool = False,
):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (LongTensor): lengths of each source sentence of
            shape `(batch)`
        return_all_hiddens (bool, optional): also return all of the
            intermediate hidden states (default: False).

    Returns:
        namedtuple:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of
              padding elements of shape `(batch, src_len)`
            - **encoder_embedding** (Tensor): the (scaled) embedding lookup
              of shape `(batch, src_len, embed_dim)`
            - **encoder_states** (List[Tensor]): all intermediate hidden
              states of shape `(src_len, batch, embed_dim)`. Only populated
              if *return_all_hiddens* is True.
    """
    out = super().forward(src_tokens, src_lengths, return_all_hiddens=return_all_hiddens)
    x, x_lengths = out.encoder_out, out.src_lengths

    # determine which output frame to select for loss evaluation/test, assuming
    # all examples in a batch are of the same length for chunk-wise training/test
    if (self.out_chunk_end is not None
            and (self.training or not self.training_stage)):
        x = x[self.out_chunk_begin: self.out_chunk_end]  # T x B x C -> W x B x C
        x_lengths = x_lengths.fill_(x.size(0))

    if self.fc_out is not None:
        x = self.fc_out(x)  # T x B x C -> T x B x V

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=out.encoder_padding_mask.transpose(0, 1),  # T x B
        encoder_embedding=out.encoder_embedding,  # None
        encoder_states=out.encoder_states,  # List[T x B x C]
        src_tokens=out.src_tokens,  # None
        src_lengths=x_lengths,  # B
    )
def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): """ Reorder encoder output according to *new_order*. Args: encoder_out: output from the ``forward()`` method new_order (LongTensor): desired order Returns: *encoder_out* rearranged according to *new_order* """ new_encoder_out: Dict[str, Tensor] = {} new_encoder_out["encoder_out"] = ( encoder_out.encoder_out if encoder_out.encoder_out is None else encoder_out.encoder_out.index_select(1, new_order)) new_encoder_out["encoder_padding_mask"] = ( encoder_out.encoder_padding_mask if encoder_out.encoder_padding_mask is None else encoder_out.encoder_padding_mask.index_select(0, new_order)) new_encoder_out["encoder_embedding"] = ( encoder_out.encoder_embedding if encoder_out.encoder_embedding is None else encoder_out.encoder_embedding.index_select(0, new_order)) src_tokens = encoder_out.src_tokens if src_tokens is not None: src_tokens = src_tokens.index_select(0, new_order) src_lengths = encoder_out.src_lengths if src_lengths is not None: src_lengths = src_lengths.index_select(0, new_order) encoder_states = encoder_out.encoder_states if encoder_states is not None: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) new_encoder_out["bottom_features"] = ( encoder_out.bottom_features if encoder_out.bottom_features is None else encoder_out.bottom_features.index_select(0, new_order)) return EncoderOut( encoder_out=new_encoder_out["encoder_out"], # T x B x C encoder_padding_mask=new_encoder_out[ "encoder_padding_mask"], # B x T encoder_embedding=new_encoder_out[ "encoder_embedding"], # B x T x C encoder_states=encoder_states, # List[T x B x C] src_tokens=src_tokens, # B x T src_lengths=src_lengths, # B x 1 bottom_features=new_encoder_out["bottom_features"], # B x T' )
def forward(self, src_tokens, src_lengths=None, **kwargs):
    b_sz, t_sz = src_tokens.shape
    padding_needed = t_sz % 2
    x = src_tokens
    if padding_needed > 0:
        padding_needed = 2 - padding_needed
        x = F.pad(x, (0, padding_needed))

    return EncoderOut(
        encoder_out=x.view(b_sz, -1, 2),
        encoder_padding_mask=None,
        encoder_embedding=None,
        encoder_states=None,
    )
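# Worked example of the pad-and-reshape above (values are illustrative only): an odd
# source length is right-padded by one column so the sequence can be viewed as pairs.
#
#     >>> src_tokens = torch.arange(5).view(1, 5)          # B=1, T=5
#     >>> F.pad(src_tokens, (0, 1)).view(1, -1, 2).shape   # pad T to 6, pair up
#     torch.Size([1, 3, 2])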
def combine_encoder_out(self, outs):
    encoder_out = torch.cat([out[0] for out in outs], 0)
    encoder_padding_mask = torch.cat([out[1] for out in outs], 1)
    encoder_embedding = torch.cat([out[2] for out in outs], 1)
    encoder_states = None
    if all(out[3] is not None for out in outs):
        encoder_states = torch.cat([out[3] for out in outs], 0)
    return EncoderOut(
        encoder_out=encoder_out,  # T x B x C
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_embedding=encoder_embedding,  # B x T x C
        encoder_states=encoder_states,  # List[T x B x C]
        src_tokens=None,
        src_lengths=None)
def reorder_encoder_out(self, encoder_out: EncoderOut, new_order):
    if encoder_out.encoder_padding_mask is not None:
        epm = encoder_out.encoder_padding_mask.index_select(1, new_order)
    else:
        epm = encoder_out.encoder_padding_mask
    return EncoderOut(
        encoder_out=encoder_out.encoder_out.index_select(1, new_order),  # T x B x C
        encoder_padding_mask=epm,  # B x T
        encoder_embedding=None,  # B x T x C
        encoder_states=None,  # List[T x B x C]
    )
def forward(self, token_id, mask_label=None, decode_label=None, label=None):
    # batch_size, can_num, can_length = candidate_id.shape
    # batch_size, _, his_length = his_id.shape
    if label is not None:
        return self.predict(token_id, label)

    token_features, _ = self.encoder(token_id)  # bsz, length, dim
    token_features = token_features[-1].transpose(0, 1)  # [:, 0, :]
    loss_mask, sample_size_mask = self.predict_mask(token_features, mask_label)

    h = token_features[:, 0:, ]
    h = EncoderOut(
        encoder_out=h,  # T x B x C
        encoder_padding_mask=None,  # B x T
        encoder_embedding=None,  # B x T x C
        encoder_states=None,  # List[T x B x C]
        src_tokens=None,
        src_lengths=None,
    )
    loss_decode, sample_size_decode = self.predict_decode(h, decode_label)

    # Alternative objectives were kept commented out here, e.g. a combined loss
    # loss = 0.5 * loss_decode + 0.5 * loss_mask, or returning only the mask or
    # decode loss with its sample size; both losses are returned instead.
    return loss_mask, sample_size_mask, loss_decode, sample_size_decode
def score(self, src_tokens, tgt_tokens):
    src_tokens = src_tokens[:, 1:]
    assert src_tokens.shape[0] == 1
    unique_tgt_tokens = tgt_tokens.unique(dim=0)

    x = src_tokens[0].cpu().numpy()
    for i in range(x.shape[0]):
        x[i] = self.src_vmap[x[i]]

    y = unique_tgt_tokens.cpu().numpy()
    for r in range(y.shape[0]):
        pad = False
        for c in range(y.shape[1]):
            if pad:
                y[r][c] = 1
            else:
                if y[r][c] == 2:
                    pad = True
                y[r][c] = self.tgt_vmap[y[r][c]]

    B = unique_tgt_tokens.shape[0]
    with torch.no_grad():
        x_tensor = torch.tensor(x)[None, :].cuda()
        y_tensor = torch.tensor(y).cuda()
        x_lens = torch.tensor([x_tensor.shape[1]]).cuda()
        y_lens = torch.ne(y_tensor, 1).sum(1) - 1

        # Transformer forward >>>
        encoder_out = self.transformer.encoder(
            x_tensor, src_lengths=x_lens, return_all_hiddens=False)
        encoder_out = EncoderOut(
            encoder_out.encoder_out.repeat(1, B, 1),
            encoder_out.encoder_padding_mask.repeat(B, 1),
            encoder_out.encoder_embedding.repeat(B, 1, 1),
            None, None, None)
        decoder_out = self.transformer.decoder(
            y_tensor[:, :-1],
            encoder_out=encoder_out,
            src_lengths=x_lens.repeat(B),
            return_all_hiddens=False,
        )
        logits = decoder_out[0]
        # <<<

        logp = torch.log_softmax(logits, 2)
        _, L, V = logp.shape
        token_logp = logp.view(B * L, V)[
            torch.arange(B * L), y_tensor[:, 1:].flatten()].view(B, L)
        y_mask = torch.arange(L).unsqueeze(0).repeat(B, 1).cuda() < y_lens[:, None]
        scores = (token_logp * y_mask).sum(1) / y_mask.sum(1)

    return unique_tgt_tokens, scores
def apply_adapter(self, enc_out):
    if self.adapter is None:
        return enc_out
    rst = self.adapter(enc_out.encoder_out)
    if enc_out.encoder_padding_mask is not None:
        # mask is B x T; transpose to T x B and broadcast over channels
        rst.masked_fill_(
            enc_out.encoder_padding_mask.transpose(0, 1).unsqueeze(-1), 0)
    return EncoderOut(
        encoder_out=rst,
        encoder_padding_mask=enc_out.encoder_padding_mask,
        encoder_embedding=enc_out.encoder_embedding,
        encoder_states=enc_out.encoder_states,
        src_tokens=enc_out.src_tokens,
        src_lengths=enc_out.src_lengths,
    )
def forward(
    self,
    src_tokens,
    src_lengths,
    return_all_hiddens: bool = False,
    token_embeddings: Optional[torch.Tensor] = None,
):
    # B x C x H x W
    x = src_tokens

    # TODO: compute padding mask from lengths
    # see causal ST encoder (w/ subsampler) for reference of how to get mask
    # from fairseq.data.data_utils import lengths_to_padding_mask
    encoder_padding_mask = None

    encoder_states = [] if return_all_hiddens else None

    # encoder layers
    for block in self.vggblocks:
        x, extra1 = block(x, return_all_hiddens=return_all_hiddens)
        if return_all_hiddens:
            encoder_states.extend(extra1)

    # B x C x H x W -> B x (HxW) x C
    _b, _c, _h, _w = x.size()
    x = x.view(_b, _c, _h * _w).permute(0, 2, 1)

    if self.embed_positions is not None:
        pos = x.new_ones((_b, _h * _w))  # B x T
        if pos.size(-1) > self.max_positions():
            raise ValueError(
                "tokens exceeds maximum length: {} > {}".format(
                    pos.size(-1), self.max_positions()))
        x = x + self.embed_positions(pos)  # input: B x T

    # B x (HxW) x C -> (HxW) x B x C
    x = x.permute(1, 0, 2)
    x = self.dropout_module(x)

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_embedding=None,  # B x T x C
        encoder_states=encoder_states,  # List[T x B x C]
        src_tokens=None,
        src_lengths=None,
    )
def reorder_encoder_out(self, encoder_out: EncoderOut, new_order):
    encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask
    src_lengths: Optional[Tensor] = encoder_out.src_lengths
    new_encoder_padding_mask = (
        encoder_padding_mask
        if encoder_padding_mask is None
        else encoder_padding_mask.index_select(1, new_order))
    new_src_lengths = (
        src_lengths
        if src_lengths is None
        else src_lengths.index_select(0, new_order))
    return EncoderOut(
        encoder_out=encoder_out.encoder_out.index_select(1, new_order),
        encoder_padding_mask=new_encoder_padding_mask,
        encoder_embedding=None,
        encoder_states=None,
        src_tokens=None,
        src_lengths=new_src_lengths,
    )
def forward(self, src_tokens, src_lengths, **kwargs):
    self.wav2vec_model.eval()
    with torch.no_grad():
        z = self.wav2vec_model.feature_extractor(src_tokens.squeeze())
        c = self.wav2vec_model.feature_aggregator(z).permute(0, 2, 1)

    subsample_factor = src_tokens.shape[1] / c.shape[1]
    src_lengths = torch.ceil(src_lengths / subsample_factor).type(torch.int64)
    src_lengths = torch.min(
        torch.tensor(c.shape[1]).to(src_lengths.device), src_lengths)

    d = super().forward(c, src_lengths, **kwargs)
    epm = d.get('encoder_padding_mask', None)
    epm = epm.t() if epm is not None else None
    return EncoderOut(
        encoder_out=d['encoder_out'],  # T x B x C
        encoder_padding_mask=epm,  # B x T
        encoder_embedding=None,  # B x T x C
        encoder_states=None,  # List[T x B x C]
    )
def forward(self, tgt, enc_out, src_len):
    assert self.is_initialized
    if self.impl == "fairseq":
        B, L, H = enc_out.shape
        encoder_out = EncoderOut(
            enc_out.transpose(0, 1),
            torch.arange(L, device=src_len.device).unsqueeze(0).expand((B, L))
            - src_len.unsqueeze(1) >= 1,
            None, None, None, None)
        output, _ = self.model.forward(
            tgt, encoder_out=encoder_out, src_lengths=src_len)
        return output
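# Illustrative check of the mask construction above (computed by hand, not part of
# the original code): a position t is treated as padding only when t - src_len >= 1,
# so the frame at index src_len itself stays unmasked.
#
#     >>> L, src_len = 4, torch.tensor([2, 4])
#     >>> torch.arange(L).unsqueeze(0).expand(2, L) - src_len.unsqueeze(1) >= 1
#     tensor([[False, False, False,  True],
#             [False, False, False, False]])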
def forward(
    self,
    src_tokens,
    src_lengths,
    cls_input: Optional[Tensor] = None,
    return_all_hiddens: bool = False,
):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`
        return_all_hiddens (bool, optional): also return all of the
            intermediate hidden states (default: False).

    Returns:
        namedtuple:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of
              padding elements of shape `(batch, src_len)`
            - **encoder_embedding** (Tensor): the (scaled) embedding lookup
              of shape `(batch, src_len, embed_dim)`
            - **encoder_states** (List[Tensor]): all intermediate hidden
              states of shape `(src_len, batch, embed_dim)`. Only populated
              if *return_all_hiddens* is True.
    """
    x, encoder_embedding = self.forward_embedding(src_tokens)
    x = x.transpose(0, 1)  # B x T x C -> T x B x C

    # U-Net part:
    encoder_padding_mask = src_tokens.eq(self.padding_idx)
    x = self.forward_unet(x, encoder_padding_mask)

    # if not return_all_hiddens, encoder states are expected to be an empty list
    encoder_states = []

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_embedding=encoder_embedding,  # B x T x C
        encoder_states=encoder_states,  # List[T x B x C]
        src_tokens=src_tokens,
        src_lengths=src_lengths,
    )
def forward(
    self,
    src_tokens,
    src_lengths: Tensor,
    enforce_sorted: bool = True,
    **unused,
):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (LongTensor): lengths of each source sentence of
            shape `(batch)`
        enforce_sorted (bool, optional): if True, `src_tokens` is expected
            to contain sequences sorted by length in a decreasing order.
            If False, this condition is not required. Default: True.
    """
    out = super().forward(src_tokens, src_lengths, enforce_sorted=enforce_sorted, **unused)
    x, encoder_padding_mask, x_lengths = (
        out.encoder_out, out.encoder_padding_mask, out.src_lengths)

    # determine which output frame to select for loss evaluation/test, assuming
    # all examples in a batch are of the same length for chunk-wise training/test
    if (self.out_chunk_end is not None
            and (self.training or not self.training_stage)):
        x = x[self.out_chunk_begin: self.out_chunk_end]  # T x B x C -> W x B x C
        x_lengths = x_lengths.fill_(x.size(0))
        assert encoder_padding_mask is None

    if self.fc_out is not None:
        x = self.fc_out(x)  # T x B x C -> T x B x V

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask
        if encoder_padding_mask is not None and encoder_padding_mask.any()
        else None,  # T x B
        encoder_embedding=None,
        encoder_states=None,
        src_tokens=None,
        src_lengths=x_lengths,  # B
    )