Example #1
  def call(self, src_token_ids, tgt_token_ids):
    """Takes as input the source and target token ids, and returns the estimated
    logits for the target sequences. Note this function should be called in 
    training mode only.

    Args:
      src_token_ids: int tensor of shape [batch_size, src_seq_len], token ids
        of source sequences.
      tgt_token_ids: int tensor of shape [batch_size, tgt_seq_len], token ids 
        of target sequences.

    Returns:
      logits: float tensor of shape [batch_size, tgt_seq_len, vocab_size].
    """
    padding_mask = utils.get_padding_mask(src_token_ids)

    src_token_embeddings = self._embedding_logits_layer(
        src_token_ids, 'embedding')
    tgt_token_embeddings = self._embedding_logits_layer(
        tgt_token_ids, 'embedding')

    encoder_outputs, fw_states, bw_states = self._encoder(
        src_token_embeddings, padding_mask, training=True)
    decoder_outputs = self._decoder(tgt_token_embeddings,
                                    fw_states,
                                    bw_states,
                                    encoder_outputs,
                                    padding_mask,
                                    training=True)

    logits = self._embedding_logits_layer(decoder_outputs, 'logits') 
    return logits
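All of these examples lean on `utils.get_padding_mask`, which is not shown on this page. A minimal TensorFlow sketch of what such a helper plausibly looks like, assuming padded positions carry a fixed token id (0 by default) and that the mask comes back in the broadcast-friendly [batch_size, 1, 1, src_seq_len] layout described in Example #4 below:

import tensorflow as tf

def get_padding_mask(token_ids, padding_value=0):
    """Float mask that is 1.0 at padded positions and 0.0 elsewhere.

    The [batch_size, 1, 1, src_seq_len] shape broadcasts over the head and
    query-length axes when the mask is added to attention logits.
    """
    mask = tf.cast(tf.equal(token_ids, padding_value), 'float32')
    return mask[:, tf.newaxis, tf.newaxis, :]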
Example #2
    def forward(self, x, x_len):
        # Input has size batch_size x sequence_length x num_channels (B x L x C)
        if self.fea_dr > 0:
            x = self.fea_dr_layer(x)

        if self.params.attn_layer > 0:
            x = x.transpose(0, 1)  # (L x B x C)
            mask = utils.get_padding_mask(x, x_len)
            x = self.attn(x, src_key_padding_mask=mask)
            x = x.transpose(0, 1)  # back to (B x L x C)

        if self.params.tcn_layer > 0:
            # Transform to (B, C, L) first
            x = x.permute(0, 2, 1)
            x = self.tcn(x)
            # Transform back to (B, L, C)
            x = x.permute(0, 2, 1)

        if self.params.rnn_n_layers > 0:
            x = self.rnn(x, x_len)

        x = self._regression(x)
        return x
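Here `get_padding_mask` is called with the features and their lengths rather than with token ids, so it is presumably a length-based mask. A plausible PyTorch sketch, assuming `x_len` is a 1-D tensor of true sequence lengths and that `self.attn` is a `torch.nn.TransformerEncoder`-style module whose `src_key_padding_mask` marks padded positions with True:

import torch

def get_padding_mask(x, x_len):
    """Boolean mask of shape (batch_size, seq_len), True at padded steps.

    `x` is (seq_len, batch_size, num_channels) after the transpose above,
    so dim 0 is time; `x_len` holds each sequence's true length.
    """
    positions = torch.arange(x.size(0), device=x.device).unsqueeze(0)  # (1, L)
    return positions >= x_len.to(x.device).unsqueeze(1)                # (B, L)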
Example #3
  def transduce(self, src_token_ids):
    """Takes as input the source token ids only, and outputs the token ids of 
    the decoded target sequences using beam search. Note this function should be 
    called in inference mode only.

    Args:
      src_token_ids: int tensor of shape [batch_size, src_seq_len], token ids
        of source sequences.

    Returns:
      decoded_ids: int tensor of shape [batch_size, decoded_seq_len], the token
        ids of the decoded target sequences using beam search.
      scores: float tensor of shape [batch_size], the scores (length-normalized 
        log-probs) of the decoded target sequences.
      tgt_src_attention: float tensor of shape [batch_size, decoded_seq_len,
        src_seq_len], the target-to-source attention weights of the top-scoring
        decoded sequences.
    """
    batch_size, src_seq_len = src_token_ids.shape
    hidden_size = self._hidden_size
    max_decode_length = src_seq_len + self._extra_decode_length
    decoding_fn = self._build_decoding_fn()

    src_token_embeddings = self._embedding_logits_layer(
        src_token_ids, 'embedding')
    padding_mask = utils.get_padding_mask(src_token_ids)
    encoder_outputs, fw_states, bw_states = self._encoder(
        src_token_embeddings, padding_mask, training=False)
    decoding_cache = {'fw_states': fw_states,
                      'bw_states': bw_states,
                      'attention_states': tf.zeros((batch_size, hidden_size)),
                      'encoder_outputs': encoder_outputs,
                      'padding_mask': padding_mask,
                      'tgt_src_attention': tf.zeros((batch_size, 0, src_seq_len))
                      }
    sos_ids = tf.ones([batch_size], dtype='int32') * SOS_ID

    bs = beam_search.BeamSearch(decoding_fn,
                                self._vocab_size,
                                batch_size,
                                self._beam_width,
                                self._alpha,
                                max_decode_length,
                                EOS_ID)

    decoded_ids, scores, decoding_cache = bs.search(sos_ids, decoding_cache)

    tgt_src_attention = decoding_cache['tgt_src_attention'].numpy()[:, 0]

    decoded_ids = decoded_ids[:, 0, 1:]
    scores = scores[:, 0]
    return decoded_ids, scores, tgt_src_attention
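The docstring describes `scores` as length-normalized log-probs, with `self._alpha` controlling the normalization strength. A self-contained sketch of the usual GNMT-style length penalty, which is what `alpha` typically means in beam search implementations like this one; the project's `beam_search` module may differ in details:

def length_normalized_score(log_prob, length, alpha):
    """Divides a hypothesis' log-prob by a penalty that grows with the
    decoded length; alpha=0 disables normalization entirely."""
    return log_prob / (((5.0 + length) / 6.0) ** alpha)

# Dividing by the penalty pulls longer hypotheses' scores toward zero,
# so they are not unfairly dominated by shorter ones.
print(length_normalized_score(-4.2, 7, 0.6))  # ~ -2.77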
Example #4
    def _build_decoding_cache(self, src_token_ids, batch_size):
        """Builds a dictionary that caches previously computed key and value feature
    maps and attention weights of the growing decoded sequence.

    Args:
      src_token_ids: int tensor of shape [batch_size, src_seq_len], token ids of 
        source sequences. 
      batch_size: int scalar, num of sequences in a batch.

    Returns:
      decoding_cache: dict of entries
        'encoder_outputs': tensor of shape [batch_size, src_seq_len, 
          hidden_size],
        'padding_mask': tensor of shape [batch_size, 1, 1, src_seq_len],

        and entries with keys 'layer_0',...,'layer_[decoder_num_layers - 1]'
        where the value associated with key 'layer_*' is a dict with entries
          'k': tensor of shape [batch_size, 0, num_heads, size_per_head],
          'v': tensor of shape [batch_size, 0, num_heads, size_per_head],
          'tgt_tgt_attention': tensor of shape [batch_size, num_heads, 
            0, 0],
          'tgt_src_attention': tensor of shape [batch_size, num_heads,
            0, src_seq_len].
    """
        padding_mask = utils.get_padding_mask(src_token_ids, SOS_ID)
        encoder_outputs = self._encode(src_token_ids,
                                       padding_mask,
                                       training=False)
        size_per_head = self._hidden_size // self._num_heads
        src_seq_len = padding_mask.shape[-1]

        decoding_cache = {
            'layer_%d' % layer: {
                'k': tf.zeros(
                    [batch_size, 0, self._num_heads, size_per_head], 'float32'),
                'v': tf.zeros(
                    [batch_size, 0, self._num_heads, size_per_head], 'float32'),
                'tgt_tgt_attention': tf.zeros(
                    [batch_size, self._num_heads, 0, 0], 'float32'),
                'tgt_src_attention': tf.zeros(
                    [batch_size, self._num_heads, 0, src_seq_len], 'float32'),
            }
            for layer in range(self._decoder._stack_size)
        }
        decoding_cache['encoder_outputs'] = encoder_outputs
        decoding_cache['padding_mask'] = padding_mask
        return decoding_cache
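The zero-length time axes in 'k', 'v' and the attention entries are the starting point for incremental decoding: each step appends its new projections instead of recomputing the whole prefix. A minimal sketch of that growth, not the project's actual decoding function:

import tensorflow as tf

batch_size, num_heads, size_per_head = 2, 8, 64
cache_k = tf.zeros([batch_size, 0, num_heads, size_per_head], 'float32')

for step in range(3):
    # Stand-in for the key projection of the token decoded at this step.
    new_k = tf.random.normal([batch_size, 1, num_heads, size_per_head])
    cache_k = tf.concat([cache_k, new_k], axis=1)  # time axis: 0 -> 1 -> 2 -> 3

print(cache_k.shape)  # (2, 3, 8, 64)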
Example #5
    def forward(self, x, x_len):
        if self.params.d_in != self.params.d_rnn:
            x = self.proj(x)

        if self.params.attn:
            x = x.transpose(0, 1)  # (seq_len, batch_size, feature_dim)
            mask = utils.get_padding_mask(x, x_len)
            x = self.attn(x, mask)
            x = x.transpose(0, 1)  # (batch_size, seq_len, feature_dim)

        if self.params.rnn_n_layers > 0:
            x = self.rnn(x, x_len)

        y = self.out(x)
        return y

    def forward(self, x, x_len):
        if self.params.d_in != self.params.d_rnn and not self.params.transformer:
            x = self.proj(x)
        if self.params.transformer:
            # Split the concatenated input, because mmt needs separate modalities.
            d_l, d_a = self.params.feature_dims[0], self.params.feature_dims[1]
            f_l = x[:, :, :d_l]
            f_a = x[:, :, d_l:d_l + d_a]
            f_v = x[:, :, d_l + d_a:]
            x, _ = self.mmt(f_l, f_a, f_v)
        if self.params.attn:
            x = x.transpose(0, 1)  # (seq_len, batch_size, feature_dim)
            mask = utils.get_padding_mask(x, x_len)
            x = self.attn(x, mask)
            x = x.transpose(0, 1)  # (batch_size, seq_len, feature_dim)

        if self.params.rnn_n_layers > 0:
            x = self.rnn(x, x_len)

        y = self.out(x)
        return y
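The slicing in the transformer branch assumes the channel axis is the concatenation of the linguistic, acoustic and visual feature blocks, in that order. A quick standalone check with hypothetical `feature_dims`:

import torch

feature_dims = (300, 74, 35)  # hypothetical linguistic/acoustic/visual sizes
x = torch.randn(4, 50, sum(feature_dims))  # (batch_size, seq_len, channels)

f_l = x[:, :, :feature_dims[0]]
f_a = x[:, :, feature_dims[0]:feature_dims[0] + feature_dims[1]]
f_v = x[:, :, feature_dims[0] + feature_dims[1]:]

assert f_l.shape == (4, 50, 300)
assert f_a.shape == (4, 50, 74)
assert f_v.shape == (4, 50, 35)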