def forward(
    self,
    src_tokens,
    src_lengths,
    max_target_position,
    return_all_hiddens: bool = False,
    token_embeddings: Optional[torch.Tensor] = None,
):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`
        return_all_hiddens (bool, optional): also return all of the
            intermediate hidden states (default: False).
        token_embeddings (torch.Tensor, optional): precomputed embeddings;
            the default `None` will recompute embeddings

    Returns:
        namedtuple:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of
              padding elements of shape `(batch, src_len)`
            - **encoder_embedding** (Tensor): the (scaled) embedding lookup
              of shape `(batch, src_len, embed_dim)`
            - **encoder_states** (List[Tensor]): all intermediate hidden
              states of shape `(src_len, batch, embed_dim)`. Only populated
              if *return_all_hiddens* is True.
    """
    x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)

    # assume that we obtain an embedding for the target positions called target_pos (B x T x C)
    # max_target_position = torch.max(target_pos)

    # create the position table based on the greatest target position
    # (for now we simply use the source sentence length)
    num_sentences, src_len, d = x.size()
    pos_table = self.constant_positional_encoding[:max_target_position + 1]
    # pos_table: B x T x C
    pos_table = pos_table.repeat(num_sentences, 1).view(
        num_sentences, max_target_position + 1, -1)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # compute padding mask
    encoder_padding_mask = src_tokens.eq(self.padding_idx)

    encoder_states = [] if return_all_hiddens else None

    # the specified list of layers in which position attention should occur
    position_layer_list = set(self.positional_layers)
    num_layers = len(position_layer_list)
    probability = torch.empty(
        size=(num_layers, num_sentences, src_len, max_target_position + 1))

    # encoder layers, counting from 1
    for count, layer in enumerate(self.layers, 1):
        if count in position_layer_list:
            reordered_position, pos_attention = position_attention(x, pos_table, 1)
            probability[count - 1] = pos_attention
            x = x + reordered_position
        x = layer(x, encoder_padding_mask)
        if return_all_hiddens:
            assert encoder_states is not None
            encoder_states.append(x)

    if self.layer_norm is not None:
        x = self.layer_norm(x)

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_embedding=encoder_embedding,  # B x T x C
        encoder_states=encoder_states,  # List[T x B x C]
        src_tokens=None,
        src_lengths=None,
    ), probability
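# NOTE (sketch): `position_attention` and `self.constant_positional_encoding` are
# external to the snippet above. The helper below is a hypothetical stand-in, not
# the actual implementation: single-head dot-product attention from the token
# states (T x B x C) to the constant position table (B x T_pos x C), returning a
# reordered positional signal (T x B x C) plus the attention weights
# (B x T x T_pos), which matches how the return values are used above.
import torch
import torch.nn.functional as F


def position_attention(x, pos_table, temperature=1.0):
    """Hypothetical sketch of the position-attention helper."""
    q = x.transpose(0, 1)                                            # B x T x C
    scores = torch.bmm(q, pos_table.transpose(1, 2)) / temperature   # B x T x T_pos
    attn = F.softmax(scores, dim=-1)
    reordered = torch.bmm(attn, pos_table)                           # B x T x C
    return reordered.transpose(0, 1), attn                           # T x B x C, B x T x T_pos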
def forward(
    self,
    src_tokens,
    src_lengths,
    cls_input: Optional[Tensor] = None,
    return_all_hiddens: bool = False,
):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`
        return_all_hiddens (bool, optional): also return all of the
            intermediate hidden states (default: False).

    Returns:
        namedtuple:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of
              padding elements of shape `(batch, src_len)`
            - **encoder_embedding** (Tensor): the (scaled) embedding lookup
              of shape `(batch, src_len, embed_dim)`
            - **encoder_states** (List[Tensor]): all intermediate hidden
              states of shape `(src_len, batch, embed_dim)`. Only populated
              if *return_all_hiddens* is True.
    """
    if self.layer_wise_attention:
        return_all_hiddens = True

    x, encoder_embedding = self.forward_embedding(src_tokens)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # compute padding mask
    encoder_padding_mask = src_tokens.eq(self.padding_idx)

    encoder_states = [] if return_all_hiddens else None

    # encoder layers
    for layer in self.layers:
        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        dropout_probability = torch.empty(1).uniform_()
        if not self.training or (dropout_probability > self.encoder_layerdrop):
            x = layer(x, encoder_padding_mask)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

    if self.layer_norm is not None:
        x = self.layer_norm(x)
        if return_all_hiddens:
            encoder_states[-1] = x

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_embedding=encoder_embedding,  # B x T x C
        encoder_states=encoder_states,  # List[T x B x C]
    )
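# NOTE (sketch): the per-layer check above is the LayerDrop rule from the cited
# paper: during training each layer is skipped when a fresh uniform sample falls
# below `encoder_layerdrop`; at inference every layer runs. A minimal
# self-contained sketch of the same rule, assuming a plain list of layers:
import torch
import torch.nn as nn


class LayerDropStack(nn.Module):
    """Minimal sketch of per-layer stochastic depth (LayerDrop)."""

    def __init__(self, layers, layerdrop=0.2):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.layerdrop = layerdrop

    def forward(self, x):
        for layer in self.layers:
            # skip the layer with probability `layerdrop`, but only during training
            if self.training and torch.rand(1).item() < self.layerdrop:
                continue
            x = layer(x)
        return x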
def forward(
    self,
    src_tokens,
    src_lengths,
    return_all_hiddens: bool = False,
):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`
        return_all_hiddens (bool, optional): also return all of the
            intermediate hidden states (default: False).

    Returns:
        namedtuple:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of
              padding elements of shape `(batch, src_len)`
            - **encoder_embedding** (Tensor): the (scaled) embedding lookup
              of shape `(batch, src_len, embed_dim)`
            - **encoder_states** (List[Tensor]): all intermediate hidden
              states of shape `(src_len, batch, embed_dim)`. Only populated
              if *return_all_hiddens* is True.
    """
    if self.conv_layers_before is not None:
        x, src_lengths, encoder_padding_mask = self.conv_layers_before(
            src_tokens, src_lengths)
    else:
        x, encoder_padding_mask = src_tokens, \
            ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1))

    x = self.dropout_module(x)
    if self.fc0 is not None:
        x = self.fc0(x)
        if self.embed_positions is not None:
            # 0s in `~encoder_padding_mask` are used as pad_idx for positional embeddings
            x = x + self.embed_positions((~encoder_padding_mask).int())
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
    elif self.embed_positions is not None:
        # 0s in `~encoder_padding_mask` are used as pad_idx for positional embeddings
        x = x + self.embed_positions((~encoder_padding_mask).int())
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

    if not encoder_padding_mask.any():
        encoder_padding_mask = None

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    attn_mask = self.get_attn_mask(src_lengths)

    encoder_states = [] if return_all_hiddens else None

    # encoder layers
    for layer in self.layers:
        x = layer(x, encoder_padding_mask, attn_mask=attn_mask)
        if return_all_hiddens:
            assert encoder_states is not None
            encoder_states.append(x)

    if self.layer_norm is not None:
        x = self.layer_norm(x)

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_embedding=None,
        encoder_states=encoder_states,  # List[T x B x C]
        src_tokens=None,
        src_lengths=None,
    )
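# NOTE (sketch): `speech_utils.sequence_mask` is external to the snippet above.
# Judging from the usage (`~sequence_mask(...)` serving as the padding mask), it
# presumably returns True at valid frames; the helper below is a hypothetical
# stand-in showing that behaviour, not the actual implementation.
import torch


def sequence_mask(lengths, max_len=None):
    """Hypothetical sketch: True at valid (non-padded) positions."""
    max_len = max_len if max_len is not None else int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)   # (max_len,)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)       # (batch, max_len)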
def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`
        return_all_hiddens (bool, optional): also return all of the
            intermediate hidden states (default: False).

    Returns:
        namedtuple:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of
              padding elements of shape `(batch, src_len)`
            - **encoder_embedding** (Tensor): the (scaled) embedding lookup
              of shape `(batch, src_len, embed_dim)`
            - **encoder_states** (List[Tensor]): all intermediate hidden
              states of shape `(src_len, batch, embed_dim)`. Only populated
              if *return_all_hiddens* is True.
    """
    x, encoder_embedding = self.forward_embedding(src_tokens)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # compute padding mask
    encoder_padding_mask = src_tokens.eq(self.padding_idx)

    encoder_states = [] if return_all_hiddens else None
    self_attn_at_list = []

    # encoder layers
    for layer in self.layers:
        x, self_attn_at = layer(x, encoder_padding_mask)
        if return_all_hiddens:
            assert encoder_states is not None
            encoder_states.append(x)
        self_attn_at_list.append(self_attn_at.transpose(1, 0).contiguous())

    self_attn_at_tensor = None
    if self_attn_at_list:
        self_attn_at_tensor = torch.stack(
            self_attn_at_list, dim=0)  # (layers, batch, heads, tgt_len, src_len)
        self_attn_at_tensor = self_attn_at_tensor.transpose(
            0, 1).contiguous()  # (batch, layers, heads, tgt_len, src_len)

    if self.layer_norm is not None:
        x = self.layer_norm(x)

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_embedding=encoder_embedding,  # B x T x C
        encoder_states=encoder_states,  # List[T x B x C]
        src_tokens=None,
        src_lengths=None,
        self_attn_at_tensor=self_attn_at_tensor,
    )
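# NOTE (sketch): shape bookkeeping for the attention collection above, assuming
# each layer returns its self-attention weights as (heads, batch, tgt_len, src_len),
# which is what the `transpose(1, 0)` and the stack comment imply (dummy values):
import torch

layers, heads, bsz, tgt_len, src_len = 6, 4, 2, 7, 7
per_layer = [torch.randn(heads, bsz, tgt_len, src_len).transpose(1, 0).contiguous()
             for _ in range(layers)]                  # each: (batch, heads, tgt, src)
stacked = torch.stack(per_layer, dim=0)               # (layers, batch, heads, tgt, src)
final = stacked.transpose(0, 1).contiguous()          # (batch, layers, heads, tgt, src)
print(final.shape)                                    # torch.Size([2, 6, 4, 7, 7])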
def forward(
    self,
    src_tokens: Tensor,
    src_lengths: Tensor,
    enforce_sorted: bool = True,
    **unused,
):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (LongTensor): lengths of each source sentence of
            shape `(batch)`
        enforce_sorted (bool, optional): if True, `src_tokens` is expected
            to contain sequences sorted by length in a decreasing order.
            If False, this condition is not required. Default: True.
    """
    if self.left_pad:
        # nn.utils.rnn.pack_padded_sequence requires right-padding;
        # convert left-padding to right-padding
        src_tokens = speech_utils.convert_padding_direction(
            src_tokens,
            src_lengths,
            left_to_right=True,
        )

    if self.conv_layers_before is not None:
        x, src_lengths, padding_mask = self.conv_layers_before(src_tokens, src_lengths)
    else:
        x, padding_mask = src_tokens, \
            ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1))

    bsz, seqlen = x.size(0), x.size(1)

    x = F.dropout(x, p=self.dropout_in, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    state_size = 2 if self.bidirectional else 1, bsz, self.hidden_size
    h0, c0 = x.new_zeros(*state_size), x.new_zeros(*state_size)

    for i in range(len(self.lstm)):
        if self.residual and i > 0:  # residual connection starts from the 2nd layer
            prev_x = x
        # pack embedded source tokens into a PackedSequence
        packed_x = nn.utils.rnn.pack_padded_sequence(
            x, src_lengths.data, enforce_sorted=enforce_sorted
        )

        # apply LSTM
        packed_outs, (_, _) = self.lstm[i](packed_x, (h0, c0))

        # unpack outputs and apply dropout
        x, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outs, padding_value=self.padding_value * 1.0)
        if i < len(self.lstm) - 1:  # not applying dropout for the last layer
            x = F.dropout(x, p=self.dropout_out, training=self.training)
        x = x + prev_x if self.residual and i > 0 else x

    assert list(x.size()) == [seqlen, bsz, self.output_units]

    encoder_padding_mask = padding_mask.t()

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask
        if encoder_padding_mask.any() else None,  # T x B
        encoder_embedding=None,
        encoder_states=None,
        src_tokens=None,
        src_lengths=src_lengths,  # B
    )
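# NOTE (sketch): the packing pattern above is the standard PyTorch round trip.
# A minimal self-contained illustration on dummy data (time-major T x B x C,
# lengths sorted in decreasing order, as `enforce_sorted=True` requires):
import torch
import torch.nn as nn

x = torch.randn(4, 2, 3)                 # T x B x C
lengths = torch.tensor([4, 2])

lstm = nn.LSTM(input_size=3, hidden_size=5)
packed = nn.utils.rnn.pack_padded_sequence(x, lengths, enforce_sorted=True)
packed_out, _ = lstm(packed)
out, out_lengths = nn.utils.rnn.pad_packed_sequence(packed_out, padding_value=0.0)

print(out.shape)      # torch.Size([4, 2, 5]) -- padded back to the max length
print(out_lengths)    # tensor([4, 2])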
def forward(self, src_video, src_lengths=None, return_all_hiddens: bool = False, **kwargs):
    """
    :param src_video: [batch, num_frames, channels, width, height]
    :param src_lengths: number of valid frames per video, of shape `(batch)`
    :param kwargs:
    :return:
    """
    # cnn module
    bs, num_fm, c, h, w = src_video.size()
    src = src_video.view(bs * num_fm, c, h, w)
    spatiol_feature = self.spatio_enc(src).squeeze()  # [bs * num_fm, embed_dim]
    if self.args.cnn_normalize_after:
        spatiol_feature = self.batchnorm(spatiol_feature)
        spatiol_feature = self.relu(spatiol_feature)
    spatiol_feature = spatiol_feature.view(
        bs, num_fm, spatiol_feature.size(-1))  # [bs, num_fm, embed_dim]

    position_tensor = torch.LongTensor(list(range(num_fm))).unsqueeze_(0).repeat(
        bs, 1).type_as(src_lengths)
    position_tensor = position_tensor.le(
        src_lengths.unsqueeze(1))  # padded positions are 0

    # add position encoding
    if self.embed_positions is not None:
        x = spatiol_feature + self.embed_positions(position_tensor)  # [bs, num_fm, embed_dim]
    else:
        x = spatiol_feature

    if self.layernorm_embedding is not None:
        x = self.layernorm_embedding(x)
    x = F.dropout(x, p=self.dropout, training=self.training)
    x = x.transpose(0, 1)  # [num_fm, bs, embed_dim]

    encoder_padding_mask = position_tensor.eq(self.padding_idx)

    if self.layer_wise_attention:
        return_all_hiddens = True

    encoder_states = [] if return_all_hiddens else None

    # encoder layers
    for layer in self.temporal_enc_layers:
        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        dropout_probability = torch.empty(1).uniform_()
        if not self.training or (dropout_probability > self.encoder_layerdrop):
            x = layer(x, encoder_padding_mask)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

    if self.layer_norm is not None:
        x = self.layer_norm(x)
        if return_all_hiddens:
            encoder_states[-1] = x

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_embedding=spatiol_feature,  # B x T x C
        encoder_states=encoder_states,  # List[T x B x C]
    )
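# NOTE (sketch): a tiny worked example of the frame mask built above, assuming
# `padding_idx == 0`. Note that `le` also marks the frame at index == length as
# valid, since the comparison is inclusive.
import torch

src_lengths = torch.tensor([3, 5])
bs, num_fm = 2, 5

position_tensor = torch.arange(num_fm).unsqueeze(0).repeat(bs, 1)   # [[0..4], [0..4]]
position_tensor = position_tensor.le(src_lengths.unsqueeze(1))
# tensor([[ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])
encoder_padding_mask = position_tensor.eq(0)   # True at padded frames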
def forward(self, model, sample, reduction="sum", log_probs=True):
    encoder_output = model.encoder(tbc=False, **sample["net_input"])
    ctc_logits = encoder_output['encoder_out']
    len_ctc_logits = (~encoder_output['encoder_padding_mask']).long().sum(-1)
    encoder_output = EncoderOut(
        encoder_out=encoder_output['encoded'].transpose(0, 1),  # T x B x C
        encoder_embedding=None,
        encoder_padding_mask=encoder_output['encoder_padding_mask'],  # B x T
        encoder_states=None,
        src_tokens=None,
        src_lengths=None,
    )

    # scheduled sampling probability: ramps up linearly once teacher forcing ends
    p = max((model.num_updates - model.teacher_forcing_updates) / 2000.0, 0.0)
    if model.num_updates <= model.teacher_forcing_updates:
        decoder_out = model.decoder(
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            encoder_out=encoder_output)
    else:
        with torch.no_grad():
            decoder_out = model.decoder(
                prev_output_tokens=sample["net_input"]["prev_output_tokens"],
                encoder_out=encoder_output)
            decoded = decoder_out["logits"].argmax(-1).int()
            device = decoded.device
            prev_self_decoded = torch.cat([
                torch.ones([decoded.size(0), 1]).int().to(device)
                * self.task.target_dictionary.eos(),
                decoded[:, :-1]
            ], 1)
            prev_output = torch.where(
                (torch.rand(decoded.size()) > p).to(device),
                sample["net_input"]["prev_output_tokens"],
                prev_self_decoded)
        decoder_out = model.decoder(prev_output_tokens=prev_output,
                                    encoder_out=encoder_output)

    target = sample["target"]
    target_lengths = sample["target_lengths"]
    lprobs, ctc_loss, ce_loss = self.compute_loss(
        model, ctc_logits, len_ctc_logits, decoder_out["logits"],
        target, target_lengths, reduction, log_probs)
    sample_size, logging_output = self.get_logging_output(
        sample, target, lprobs, ctc_loss, ce_loss)
    loss = ctc_loss + ce_loss
    logging_output['schedule_sampling'] = p

    if not model.training:
        import editdistance

        c_err = 0
        c_len = 0
        self.decoder.step_forward_fn = model.decoder
        input_lengths = (~encoder_output.encoder_padding_mask).sum(-1)
        with torch.no_grad():
            decodeds = self.decoder.decode(encoder_output, 50)
        for decoded, t, inp_l in zip(decodeds, sample["target"], input_lengths):
            decoded = decoded[0]['tokens']
            non_special = (t != self.task.target_dictionary.pad()) & (
                t != self.task.target_dictionary.eos())
            targ = t[non_special]
            targ_units_arr = targ.tolist()
            pred_units_arr = decoded.tolist()
            c_err += editdistance.eval(pred_units_arr, targ_units_arr)
            c_len += len(targ_units_arr)
        logging_output["c_errors"] = c_err
        logging_output["c_total"] = c_len

    return loss, sample_size, logging_output
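# NOTE (sketch): the mixing step above follows the scheduled-sampling recipe:
# once `model.num_updates` exceeds `teacher_forcing_updates`, each decoder input
# token is replaced by the model's own shifted prediction with probability `p`,
# which ramps up by 1/2000 per update. A minimal sketch of the token mix in
# isolation, with made-up tensors:
import torch

p = 0.3
gold_prev = torch.tensor([[2, 5, 7, 9]])       # teacher-forced decoder inputs
model_prev = torch.tensor([[2, 5, 6, 8]])      # model's own shifted argmax predictions
keep_gold = torch.rand(gold_prev.size()) > p   # True -> keep the gold token
prev_output = torch.where(keep_gold, gold_prev, model_prev)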
def forward(
    self,
    src_tokens,
    src_lengths,
    cls_input: Optional[Tensor] = None,
    return_all_hiddens: bool = False,
):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`
        return_all_hiddens (bool, optional): also return all of the
            intermediate hidden states (default: False).

    Returns:
        namedtuple:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of
              padding elements of shape `(batch, src_len)`
            - **encoder_embedding** (Tensor): the (scaled) embedding lookup
              of shape `(batch, src_len, embed_dim)`
            - **encoder_states** (List[Tensor]): all intermediate hidden
              states of shape `(src_len, batch, embed_dim)`. Only populated
              if *return_all_hiddens* is True.
    """
    x, encoder_embedding = self.forward_embedding(src_tokens)
    x = x.transpose(0, 1)  # B x T x C -> T x B x C

    # if not return_all_hiddens, encoder states are expected to be an empty list;
    # we do not support encoder hiddens here, but need to satisfy the interface
    encoder_states = []

    # U-Net part:
    x_unet = x
    encoder_padding_mask = src_tokens.eq(self.padding_idx)
    x_unet = self.forward_unet(x_unet, encoder_padding_mask)

    # Transformer part:
    x_transformer = x
    for layer in self.transformer_layers:
        x_transformer = layer(x_transformer, encoder_padding_mask)
        if return_all_hiddens:
            assert encoder_states is not None
            encoder_states.append(x_transformer)

    # combine U-Net representations and Transformer representations
    x, _ = torch.stack([x_transformer, x_unet], dim=0).max(0)

    if self.layer_norm is not None:
        x = self.layer_norm(x)

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_embedding=encoder_embedding,  # B x T x C
        encoder_states=encoder_states,  # List[T x B x C]
        src_tokens=src_tokens,
        src_lengths=src_lengths,
    )
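# NOTE (sketch): the combination step above is just an elementwise max of the two
# streams; a short check of the equivalence on dummy T x B x C tensors:
import torch

a, b = torch.randn(5, 2, 8), torch.randn(5, 2, 8)
combined, _ = torch.stack([a, b], dim=0).max(0)
assert torch.equal(combined, torch.maximum(a, b))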
def batch_beam_decode(encoder_output, step_forward_fn, incremental_state,
                      SOS_ID, EOS_ID, vocab_size, beam_size=1,
                      max_decode_len=100):
    """
    encoder_output:
        encoder_out=x,  # T x B x C
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_embedding=encoder_embedding,  # B x T x C
        encoder_states=encoder_states,  # List[T x B x C]
    """
    encoded = encoder_output.encoder_out  # T x B x C
    len_encoded = (~encoder_output.encoder_padding_mask).sum(-1)
    batch_size = len_encoded.size(0)
    device = encoded.device
    d_output = vocab_size

    # beam search initialization
    # repeat each sample in the batch along the batch axis: [1, 2, 3, 4] -> [1, 1, 2, 2, 3, 3, 4, 4]
    encoded = encoded[:, None, :, :].repeat(
        1, beam_size, 1, 1)  # [batch_size, beam_size, *, hidden_units]
    encoded = encoded.view(batch_size * beam_size, -1, encoded.size(-1))
    len_encoded = len_encoded[:, None].repeat(1, beam_size).view(
        -1)  # [batch_size * beam_size]
    encoder_padding_mask = encoder_output.encoder_padding_mask.repeat(
        1, beam_size).reshape(batch_size * beam_size, -1)
    encoder_output = EncoderOut(
        encoder_out=encoded.transpose(0, 1),  # T x B x C
        encoder_embedding=None,
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_states=None,  # List[T x B x C]
        src_tokens=None,
        src_lengths=None)

    # [[<S>, <S>, ..., <S>]], shape: [batch_size * beam_size, 1]
    preds = torch.ones([batch_size * beam_size, 1]).long().to(device) * SOS_ID
    logits = torch.zeros([batch_size * beam_size, 0, d_output]).float().to(device)
    len_decoded = torch.ones_like(len_encoded)
    # the scores must be [0, -inf, -inf, ...] at init, since all beams of a sample
    # start from the same prefix
    scores = torch.tensor([0.0] + [-float("inf")] * (beam_size - 1)).float().repeat(
        batch_size).to(device)  # [batch_size * beam_size]
    finished = torch.zeros_like(scores).bool().to(device)

    # base indices used to map per-sample beam indices back into the flattened batch
    base_indices = torch.arange(batch_size)[:, None].repeat(
        1, beam_size).view(-1).to(device)

    for _ in range(max_decode_len):
        # loop state: preds, scores, logits, len_decoded, finished
        decoder_output = step_forward_fn(
            prev_output_tokens=preds,
            encoder_out=encoder_output,
            incremental_state=incremental_state)
        cur_logits = decoder_output["logits"]
        logits = torch.cat([logits, cur_logits], 1)  # [batch*beam, t, size_output]
        z = F.log_softmax(cur_logits[:, -1, :], dim=-1)  # [batch*beam, size_output]

        # rank the combined scores
        next_scores, next_preds = torch.topk(z, k=beam_size, sorted=True, dim=-1)

        # beamed scores & pruning
        scores = scores[:, None] + next_scores  # [batch_size * beam_size, beam_size]
        scores = scores.view(batch_size, beam_size * beam_size)

        _, k_indices = torch.topk(scores, k=beam_size)
        k_indices = base_indices * beam_size * beam_size + k_indices.view(
            -1)  # [batch_size * beam_size]
        # update scores
        scores = scores.view(-1)[k_indices]
        # update predictions
        next_preds = next_preds.view(-1)[k_indices]

        # k_indices: [0 ~ batch*beam*beam], preds: [0 ~ batch*beam]
        # preds, cache_lm, cache_decoder: these data are shared during the beam expansion over the vocab
        preds = preds[k_indices // beam_size]
        preds = torch.cat([preds, next_preds[:, None]], dim=1)  # [batch_size * beam_size, i]

        has_eos = next_preds.eq(EOS_ID)
        finished = torch.logical_or(finished, has_eos)
        len_decoded += 1 - finished.int()

        if finished.int().sum() == finished.size(0):
            break

    len_decoded -= 1 - finished.int()  # for decoded length cut by encoded length
    preds = preds[:, 1:]

    # torch.topk is used here just to sort `scores`
    scores_sorted, sorted_idx = torch.topk(scores.view(batch_size, beam_size),
                                           k=beam_size, sorted=True)
    sorted_idx = base_indices * beam_size + sorted_idx.view(-1)  # [batch_size * beam_size]

    # [batch_size * beam_size, ...] -> [batch_size, beam_size, ...]
    preds_sorted = preds[sorted_idx].view(
        batch_size, beam_size, -1)  # [batch_size, beam_size, max_length]
    len_decoded_sorted = len_decoded[sorted_idx].view(batch_size, beam_size)
    scores_sorted = scores[sorted_idx].view(batch_size, beam_size)

    return preds_sorted, len_decoded_sorted, scores_sorted
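# NOTE (sketch): the trickiest step in the loop above is the pruning: per-sample
# scores are viewed as (batch, beam*beam), the top `beam` are picked, and
# `base_indices * beam_size * beam_size` shifts those per-sample indices back
# into the flattened (batch*beam*beam) space. A small worked example with
# batch_size=2, beam_size=2 and made-up scores:
import torch

batch_size, beam_size = 2, 2
base_indices = torch.arange(batch_size)[:, None].repeat(1, beam_size).view(-1)
# base_indices -> tensor([0, 0, 1, 1])

scores = torch.tensor([[0.1, 0.9, 0.5, 0.2],    # sample 0: beam*beam candidates
                       [0.3, 0.4, 0.8, 0.7]])   # sample 1: beam*beam candidates
_, k_indices = torch.topk(scores, k=beam_size)  # per-sample winners: [[1, 2], [2, 3]]
k_indices = base_indices * beam_size * beam_size + k_indices.view(-1)
# k_indices -> tensor([1, 2, 6, 7]): positions in the flattened (batch*beam*beam) view
picked = scores.view(-1)[k_indices]             # tensor([0.9000, 0.5000, 0.8000, 0.7000])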