def decode(self,
             decoder_input,
             encoder_output,
             encoder_decoder_attention_bias,
             decoder_self_attention_bias,
             hparams,
             cache=None,
             decode_loop_step=None,
             nonpadding=None,
             losses=None):
    """Decode Universal Transformer outputs from encoder representation.

    It is similar to "transformer.decode", but it uses
    "universal_transformer_util.universal_transformer_decoder" instead of
    "transformer.transformer_decoder".

    Args:
      decoder_input: inputs to bottom of the model. [batch_size, decoder_length,
        hidden_dim]
      encoder_output: Encoder representation. [batch_size, input_length,
        hidden_dim]
      encoder_decoder_attention_bias: Bias and mask weights for encoder-decoder
        attention. [batch_size, input_length]
      decoder_self_attention_bias: Bias and mask weights for decoder
        self-attention. [batch_size, decoder_length]
      hparams: hyperparameters for the model.
      cache: Unimplemented.
      decode_loop_step: Unused.
      nonpadding: optional Tensor with shape [batch_size, decoder_length]
      losses: Unused.

    Returns:
       Tuple of:
         Final decoder representation. [batch_size, decoder_length,
            hidden_dim]
         decoder_extra_output: extra decoder output used in some variants of
            the model (e.g. in ACT, to pass the ponder time to the body)

    """
    del decode_loop_step
    del losses
    # TODO(dehghani): enable caching.
    del cache

    decoder_input = tf.nn.dropout(decoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    # No caching in Universal Transformers!
    (decoder_output, dec_extra_output) = (
        universal_transformer_util.universal_transformer_decoder(
            decoder_input,
            encoder_output,
            decoder_self_attention_bias,
            encoder_decoder_attention_bias,
            hparams,
            nonpadding=nonpadding,
            save_weights_to=self.attention_weights))

    # Expand since t2t expects 4d tensors.
    return tf.expand_dims(decoder_output, axis=2), dec_extra_output
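
# Hedged usage sketch (not part of the original example above): with ACT, the
# extra output returned by decode() is a (ponder_times, remainders) pair. A
# caller can fold it into an auxiliary ponder loss, mirroring the pattern used
# in the examples further below. `tf` is assumed to be TensorFlow 1.x.
def act_ponder_loss(extra_output, act_loss_weight):
  """Mean ACT ponder cost: average of ponder times plus remainders, weighted."""
  ponder_times, remainders = extra_output
  return act_loss_weight * tf.reduce_mean(ponder_times + remainders)

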
    def decode_inputs_to_outputs(self, decoder_embed_inputs, encoder_outputs, encoder_attn_bias,
                                 rule_id_input_placeholder, mem_contexts, mem_outputs, global_step,
                                 score, obj_tensors=None):
        if self.hparams.pos == 'timing':
            decoder_embed_inputs = common_attention.add_timing_signal_1d(decoder_embed_inputs)
            print('Use positional encoding in decoder text.')
        decoder_embed_inputs = self.update_decoder_embedding(decoder_embed_inputs, score, self.model_config.beam_search_size)

        # Causal (lower-triangular) bias so each target position can only attend
        # to itself and earlier positions in decoder self-attention.
        decoder_attn_bias = common_attention.attention_bias_lower_triangle(
            tf.shape(decoder_embed_inputs)[1])
        decoder_embed_inputs = tf.nn.dropout(decoder_embed_inputs,
                                             1.0 - self.hparams.layer_prepostprocess_dropout)
        if 'direct' in self.model_config.memory:
            assert 'direct_bert_output' in obj_tensors
            decoder_output = transformer.transformer_multi_decoder(
                decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
                encoder_attn_bias, obj_tensors['direct_bert_output'], obj_tensors['direct_bert_bias'],
                self.hparams, save_weights_to=obj_tensors,
                direct_mode=self.model_config.direct_mode)

            if self.model_config.npad_mode == 'static_seq':
                decoder_output = tf.nn.conv1d(decoder_output, obj_tensors['npad_w'], 1, 'SAME')

            return decoder_output, decoder_output, None
        elif 'rule' in self.model_config.memory:
            decoder_output, contexts = transformer.transformer_decoder_contexts(
                decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
                encoder_attn_bias, self.hparams)

            # encoder_gate_w = tf.get_variable('encoder_gate_w', shape=(
            #     1, self.model_config.dimension, 1))
            # encoder_gate_b = tf.get_variable('encoder_gate_b', shape=(1, 1, 1))
            # encoder_gate = tf.tanh(encoder_gate_b + tf.nn.conv1d(encoder_outputs, encoder_gate_w, 1, 'SAME'))
            # encoder_context_outputs = tf.expand_dims(tf.reduce_mean(encoder_outputs * encoder_gate, axis=1), axis=1)
            cur_context = contexts[0]  # tf.concat(contexts, axis=-1)
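            # Look up memory slots for the candidate rule ids and flatten them to
            # [batch_size, max_target_rule_sublen * max_cand_rules, dimension].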
            cur_mem_contexts = tf.stack(self.embedding_fn(rule_id_input_placeholder, mem_contexts), axis=1)
            cur_mem_outputs = tf.stack(self.embedding_fn(rule_id_input_placeholder, mem_outputs), axis=1)
            cur_mem_contexts = tf.reshape(cur_mem_contexts,
                                          [self.model_config.batch_size,
                                           self.model_config.max_target_rule_sublen*self.model_config.max_cand_rules,
                                           self.model_config.dimension])
            cur_mem_outputs = tf.reshape(cur_mem_outputs,
                                         [self.model_config.batch_size,
                                          self.model_config.max_target_rule_sublen*self.model_config.max_cand_rules,
                                          self.model_config.dimension])

            # bias = tf.expand_dims(
            #     -1e9 * tf.to_float(tf.equal(tf.stack(rule_id_input_placeholder, axis=1), 0)),
            #     axis=1)
            # weights = tf.nn.softmax(bias + tf.matmul(cur_context, cur_mem_contexts, transpose_b=True))
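            # Attend from the decoder context over the rule-memory keys and read
            # out a weighted sum of the corresponding memory values.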
            weights = tf.nn.softmax(tf.matmul(cur_context, cur_mem_contexts, transpose_b=True))
            mem_output = tf.matmul(weights, cur_mem_outputs)

            # trainable_mem = 'stopgrad' not in self.model_config.rl_configs
            temp_output = tf.concat((decoder_output, mem_output), axis=-1)
            # w_u = tf.get_variable('w_ffn', shape=(
            #     1, self.model_config.dimension*2, self.model_config.dimension), trainable=trainable_mem)
            # b_u = tf.get_variable('b_ffn', shape=(
            #     1, 1, self.model_config.dimension), trainable=trainable_mem)
            # w_u.reuse_variables()
            # b_u.reuse_variables()
            # tf.get_variable_scope().reuse_variables()
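            # Position-wise (1x1 convolution) projection that fuses the decoder
            # output with the memory read-out back down to `dimension` channels.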
            w_t = tf.get_variable('w_ffn', shape=(
                1, self.model_config.dimension*2, self.model_config.dimension), trainable=True)
            b_t = tf.get_variable('b_ffn', shape=(
                1, 1, self.model_config.dimension), trainable=True)
            # w = tf.cond(tf.equal(tf.mod(self.global_step, 2), 0), lambda: w_t, lambda: w_u)
            # b = tf.cond(tf.equal(tf.mod(self.global_step, 2), 0), lambda: b_t, lambda: b_u)

            mem_output = tf.nn.conv1d(temp_output, w_t, 1, 'SAME') + b_t
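            # Warm-up gate: use the memory-augmented output only after the global
            # step passes memory_prepare_step; otherwise keep the plain decoder output.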
            g = tf.greater(global_step, tf.constant(self.model_config.memory_prepare_step, dtype=tf.int64))
            final_output = tf.cond(g, lambda: mem_output, lambda: decoder_output)
            return final_output, decoder_output, cur_context
        else:
            if self.model_config.architecture == 'ut2t':
                (decoder_output, decoder_extra_output) = universal_transformer_util.universal_transformer_decoder(
                    decoder_embed_inputs, encoder_outputs,
                    decoder_attn_bias, encoder_attn_bias, self.hparams,
                    save_weights_to=obj_tensors)
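                # ACT extra output is a (ponder_times, remainders) pair; its
                # weighted mean is added as an auxiliary decoder loss when training.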
                dec_ponder_times, dec_remainders = decoder_extra_output
                extra_dec_loss = (
                        self.hparams.act_loss_weight *
                        tf.reduce_mean(dec_ponder_times + dec_remainders))
                if self.is_train:
                    obj_tensors['extra_decoder_loss'] = extra_dec_loss
            else:
                decoder_output = transformer.transformer_decoder(
                    decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
                    encoder_attn_bias, self.hparams, save_weights_to=obj_tensors,
                    npad_mode=self.model_config.npad_mode)
            final_output = decoder_output
            return final_output, decoder_output, None
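
# Hedged sketch (not from the original code): the rule-memory read-out above,
# distilled into a standalone helper. `context` is [batch, length, dim] and
# `mem_keys`/`mem_values` are [batch, num_slots, dim]; the names are illustrative.
def read_rule_memory(context, mem_keys, mem_values):
  """Soft attention over memory slots: softmax(Q K^T) V."""
  weights = tf.nn.softmax(tf.matmul(context, mem_keys, transpose_b=True))
  return tf.matmul(weights, mem_values)

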
    def context_encoder(self, contexts_emb, contexts, abbr_inp_emb=None):
        """

        :param contexts_emb: a tensor of [batch_size, max_context_len, emb_dim]
        :param contexts: a list of [max_context_len, batch_size]
        :param abbr_inp_emb: a tensor of [batch_size, context_len, emb_dim]; only used in the 'abbr_ut2t' encoder mode
        :return:
            encoder_output: [batch_size, context_len, channel_dim]
            weights: a dict of multi-head attention weights collected via
                     save_weights_to, each of shape
                     [batch_size, num_heads, context_len, context_len]
            extra_loss: the ACT ponder loss for the 'ut2t' and 'abbr_ut2t'
                     encoder modes, otherwise None
        """
        weights = {}
        # Create a bias tensor that masks the padded positions with large negative
        # values: input=[batch_size, context_len], output=[batch_size, 1, 1, context_len].
        contexts_bias = common_attention.attention_bias_ignore_padding(
            tf.to_float(
                tf.equal(tf.stack(contexts, axis=1),
                         self.voc.encode(constant.PAD))))
        # add dropout to context input [batch_size, max_context_len, emb_dim]
        contexts_emb = tf.nn.dropout(
            contexts_emb, 1.0 - self.hparams.layer_prepostprocess_dropout)
        # get the output vector of transformer, [batch_size, context_len, channel_dim]
        # encoder_ouput = transformer.transformer_encoder_abbr(
        #     contexts_emb, contexts_bias, abbr_inp_emb,
        #     tf.zeros([self.model_config.batch_size,1,1,1]), self.hparams,
        #     save_weights_to=weights)
        if self.model_config.encoder_mode == 't2t':
            encoder_output = transformer.transformer_encoder(
                contexts_emb,
                contexts_bias,
                self.hparams,
                save_weights_to=weights)
            extra_loss = None
        elif self.model_config.encoder_mode == 'ut2t':
            encoder_output, extra_output = universal_transformer_util.universal_transformer_encoder(
                contexts_emb,
                contexts_bias,
                self.hparams,
                save_weights_to=weights)
            enc_ponder_times, enc_remainders = extra_output
            extra_loss = (self.hparams.act_loss_weight *
                          tf.reduce_mean(enc_ponder_times + enc_remainders))
        elif self.model_config.encoder_mode == 'abbr_ut2t':
            encoder_output, extra_output = universal_transformer_util.universal_transformer_encoder(
                contexts_emb,
                contexts_bias,
                self.hparams,
                save_weights_to=weights)
            enc_ponder_times, enc_remainders = extra_output
            extra_loss = (self.hparams.act_loss_weight *
                          tf.reduce_mean(enc_ponder_times + enc_remainders))

            # Second, decoder-style UT pass: the abbreviation embedding attends over
            # the encoded context; the zero self-attention bias applies no masking.
            encoder_output2, extra_output2 = universal_transformer_util.universal_transformer_decoder(
                abbr_inp_emb, encoder_output,
                tf.zeros([self.model_config.batch_size, 1, 1, 1]),
                contexts_bias, self.hparams)
            enc_ponder_times2, enc_remainders2 = extra_output2
            extra_loss2 = (self.hparams.act_loss_weight *
                           tf.reduce_mean(enc_ponder_times2 + enc_remainders2))
            extra_loss += extra_loss2

        else:
            raise ValueError('Unknown encoder_mode.')

        return encoder_output, weights, extra_loss