def call(self, src_token_ids, tgt_token_ids):
    """Compute target-sequence logits from source and target token ids.

    Intended for training mode only: both the encoder and the decoder are
    invoked with `training=True`.

    Args:
      src_token_ids: int tensor of shape [batch_size, src_seq_len], token ids
        of source sequences.
      tgt_token_ids: int tensor of shape [batch_size, tgt_seq_len], token ids
        of target sequences.

    Returns:
      logits: float tensor of shape [batch_size, tgt_seq_len, vocab_size].
    """
    # The same layer is used in 'embedding' mode for both sequences and in
    # 'logits' mode for the output projection.
    shared_layer = self._embedding_logits_layer

    src_mask = utils.get_padding_mask(src_token_ids)
    src_embedded = shared_layer(src_token_ids, 'embedding')
    tgt_embedded = shared_layer(tgt_token_ids, 'embedding')

    encoded = self._encoder(src_embedded, src_mask, training=True)
    encoder_outputs, fw_states, bw_states = encoded

    decoded = self._decoder(
        tgt_embedded,
        fw_states,
        bw_states,
        encoder_outputs,
        src_mask,
        training=True,
    )
    return shared_layer(decoded, 'logits')
def forward(self, x, x_len):
    """Run the enabled stages in order and return the regression output.

    Stages, each applied only when its switch is positive: feature dropout
    (`self.fea_dr`), attention (`params.attn_layer`), TCN
    (`params.tcn_layer`) and RNN (`params.rnn_n_layers`). The regression
    head is always applied last.

    Args:
        x: input of size batch_size x sequence_length x num_channels
            (B x L x C).
        x_len: forwarded to the padding-mask helper and to the RNN --
            presumably the per-sequence lengths; TODO confirm.

    Returns:
        Whatever `self._regression` produces for the processed features.
    """
    cfg = self.params
    features = x

    if self.fea_dr > 0:
        features = self.fea_dr_layer(features)

    if cfg.attn_layer > 0:
        # The attention module is fed sequence-first data: (L x B x C).
        seq_first = features.transpose(0, 1)
        pad_mask = utils.get_padding_mask(seq_first, x_len)
        attended = self.attn(seq_first, src_key_padding_mask=pad_mask)
        features = attended.transpose(0, 1)

    if cfg.tcn_layer > 0:
        # The TCN is fed channel-first data: (B, C, L); restore (B, L, C)
        # afterwards.
        channel_first = features.permute(0, 2, 1)
        features = self.tcn(channel_first).permute(0, 2, 1)

    if cfg.rnn_n_layers > 0:
        features = self.rnn(features, x_len)

    return self._regression(features)
def transduce(self, src_token_ids):
    """Decode target sequences from source token ids using beam search.

    Intended for inference mode only: the encoder is invoked with
    `training=False`.

    Args:
      src_token_ids: int tensor of shape [batch_size, src_seq_len], token ids
        of source sequences.

    Returns:
      decoded_ids: int tensor of shape [batch_size, decoded_seq_len], the
        token ids of the decoded target sequences using beam search.
      scores: float tensor of shape [batch_size], the scores
        (length-normalized log-probs) of the decoded target sequences.
      tgt_src_attention: numpy array (converted via `.numpy()`) of shape
        [batch_size, tgt_seq_len, src_seq_len], the target-to-source
        attention weights.
    """
    batch_size, src_seq_len = src_token_ids.shape
    decode_limit = src_seq_len + self._extra_decode_length
    step_fn = self._build_decoding_fn()

    embedded_src = self._embedding_logits_layer(src_token_ids, 'embedding')
    src_mask = utils.get_padding_mask(src_token_ids)
    encoder_outputs, fw_states, bw_states = self._encoder(
        embedded_src, src_mask, training=False)

    # Initial beam-search state. The attention history starts with a
    # zero-length target axis.
    initial_cache = {
        'fw_states': fw_states,
        'bw_states': bw_states,
        'attention_states': tf.zeros((batch_size, self._hidden_size)),
        'encoder_outputs': encoder_outputs,
        'padding_mask': src_mask,
        'tgt_src_attention': tf.zeros((batch_size, 0, src_seq_len)),
    }

    start_ids = tf.ones([batch_size], dtype='int32') * SOS_ID
    searcher = beam_search.BeamSearch(
        step_fn,
        self._vocab_size,
        batch_size,
        self._beam_width,
        self._alpha,
        decode_limit,
        EOS_ID,
    )
    all_ids, all_scores, final_cache = searcher.search(start_ids, initial_cache)

    # Keep only index 0 along axis 1 of every result (presumably the
    # best-scoring beam -- confirm against `beam_search`), and drop the first
    # position of the decoded ids (presumably the leading SOS token).
    attention = final_cache['tgt_src_attention'].numpy()[:, 0]
    return all_ids[:, 0, 1:], all_scores[:, 0], attention
def _build_decoding_cache(self, src_token_ids, batch_size):
    """Create the initial cache used while decoding a growing sequence.

    The cache holds the encoder results plus, for every decoder layer, empty
    (zero-length) key/value feature maps and attention weights.

    Args:
      src_token_ids: int tensor of shape [batch_size, src_seq_len], token ids
        of source sequences.
      batch_size: int scalar, num of sequences in a batch.

    Returns:
      decoding_cache: dict of entries
        'encoder_outputs': tensor of shape
          [batch_size, src_seq_len, hidden_size],
        'padding_mask': tensor of shape [batch_size, 1, 1, src_seq_len],
        and entries with keys 'layer_0',...,'layer_[decoder_num_layers - 1]'
        where the value associated with key 'layer_*' is a dict with entries
          'k': tensor of shape [batch_size, 0, num_heads, size_per_head],
          'v': tensor of shape [batch_size, 0, num_heads, size_per_head],
          'tgt_tgt_attention': tensor of shape [batch_size, num_heads, 0, 0],
          'tgt_src_attention': tensor of shape
            [batch_size, num_heads, 0, src_seq_len].
    """
    # NOTE(review): SOS_ID is handed to get_padding_mask as its second
    # argument -- presumably the padding value; confirm that SOS_ID doubles
    # as the padding id in this vocabulary.
    padding_mask = utils.get_padding_mask(src_token_ids, SOS_ID)
    encoder_outputs = self._encode(src_token_ids, padding_mask, training=False)

    num_heads = self._num_heads
    size_per_head = self._hidden_size // num_heads
    src_seq_len = padding_mask.shape[-1]

    decoding_cache = {}
    for layer in range(self._decoder._stack_size):
        # Every layer gets its own freshly created empty tensors.
        decoding_cache['layer_%d' % layer] = {
            'k': tf.zeros(
                [batch_size, 0, num_heads, size_per_head], 'float32'),
            'v': tf.zeros(
                [batch_size, 0, num_heads, size_per_head], 'float32'),
            'tgt_tgt_attention': tf.zeros(
                [batch_size, num_heads, 0, 0], 'float32'),
            'tgt_src_attention': tf.zeros(
                [batch_size, num_heads, 0, src_seq_len], 'float32'),
        }

    decoding_cache['encoder_outputs'] = encoder_outputs
    decoding_cache['padding_mask'] = padding_mask
    return decoding_cache
def forward(self, x, x_len):
    """Apply the optional projection, attention and RNN stages, then the output layer.

    Args:
        x: input features; the attention branch treats them as
            (batch_size, seq_len, feature_dim).
        x_len: forwarded to the padding-mask helper and to the RNN --
            presumably the per-sequence lengths; TODO confirm.

    Returns:
        The result of `self.out` applied to the processed features.
    """
    params = self.params

    # Project only when the input width differs from the RNN width.
    if params.d_in != params.d_rnn:
        x = self.proj(x)

    # Plain truthiness test instead of `== True` (PEP 8). Boolean flags
    # behave exactly as before; any other truthy value now also enables
    # attention.
    if params.attn:
        x = x.transpose(0, 1)  # (seq_len, batch_size, feature_dim)
        mask = utils.get_padding_mask(x, x_len)
        x = self.attn(x, mask)
        x = x.transpose(0, 1)  # (batch_size, seq_len, feature_dim)

    if params.rnn_n_layers > 0:
        x = self.rnn(x, x_len)

    return self.out(x)
def forward(self, x, x_len):
    """Apply projection or multimodal transformer, then attention, RNN and output.

    Stage selection is driven by `self.params`:
      * the projection runs only when `d_in != d_rnn` and the transformer
        is disabled;
      * with `transformer` enabled, the last axis of `x` is split into three
        consecutive slices using `feature_dims[0]` and `feature_dims[1]`,
        and the slices are handed to `self.mmt`;
      * `attn` and `rnn_n_layers > 0` enable the attention and RNN stages.

    Args:
        x: input features; indexed as a 3-D array whose last axis holds the
            concatenated feature groups, treated by the attention branch as
            (batch_size, seq_len, feature_dim).
        x_len: forwarded to the padding-mask helper and to the RNN --
            presumably the per-sequence lengths; TODO confirm.

    Returns:
        The result of `self.out` applied to the processed features.
    """
    cfg = self.params

    if cfg.d_in != cfg.d_rnn and not cfg.transformer:
        x = self.proj(x)

    if cfg.transformer:
        # Split features, because mmt needs separate modalities.
        first_end = cfg.feature_dims[0]
        second_end = cfg.feature_dims[0] + cfg.feature_dims[1]
        f_l = x[:, :, :first_end]
        f_a = x[:, :, first_end:second_end]
        f_v = x[:, :, second_end:]
        x, _ = self.mmt(f_l, f_a, f_v)

    if cfg.attn:
        seq_first = x.transpose(0, 1)  # (seq_len, batch_size, feature_dim)
        pad_mask = utils.get_padding_mask(seq_first, x_len)
        attended = self.attn(seq_first, pad_mask)
        x = attended.transpose(0, 1)  # (batch_size, seq_len, feature_dim)

    if cfg.rnn_n_layers > 0:
        x = self.rnn(x, x_len)

    return self.out(x)