def decode(self,
           decoder_input,
           encoder_output,
           encoder_decoder_attention_bias,
           decoder_self_attention_bias,
           hparams,
           cache=None,
           decode_loop_step=None,
           nonpadding=None,
           losses=None):
  """Decode Universal Transformer outputs from encoder representation.

  It is similar to "transformer.decode", but it uses
  "universal_transformer_util.universal_transformer_decoder" instead of
  "transformer.transformer_decoder".

  Args:
    decoder_input: inputs to bottom of the model. [batch_size, decoder_length,
      hidden_dim]
    encoder_output: Encoder representation. [batch_size, input_length,
      hidden_dim]
    encoder_decoder_attention_bias: Bias and mask weights for encoder-decoder
      attention. [batch_size, input_length]
    decoder_self_attention_bias: Bias and mask weights for decoder
      self-attention. [batch_size, decoder_length]
    hparams: hyperparameters for model.
    cache: Unimplemented.
    decode_loop_step: Unused.
    nonpadding: optional Tensor with shape [batch_size, decoder_length]
    losses: Unused.

  Returns:
    Tuple of:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
      encoder_extra_output: which is extra encoder output used in some
        variants of the model (e.g. in ACT, to pass the ponder-time to body)
  """
  # No caching in Universal Transformers.
  # TODO(dehghani): enable caching.
  del cache
  del decode_loop_step
  del losses

  # TF1-style dropout: the second argument is keep_prob, not the drop rate.
  keep_prob = 1.0 - hparams.layer_prepostprocess_dropout
  dropped_input = tf.nn.dropout(decoder_input, keep_prob)

  decoder_output, dec_extra_output = (
      universal_transformer_util.universal_transformer_decoder(
          dropped_input,
          encoder_output,
          decoder_self_attention_bias,
          encoder_decoder_attention_bias,
          hparams,
          nonpadding=nonpadding,
          save_weights_to=self.attention_weights))

  # t2t downstream code expects 4-D tensors; insert a singleton axis.
  return tf.expand_dims(decoder_output, axis=2), dec_extra_output
def decode(self,
           decoder_input,
           encoder_output,
           encoder_decoder_attention_bias,
           decoder_self_attention_bias,
           hparams,
           cache=None,
           decode_loop_step=None,
           nonpadding=None,
           losses=None):
  """Decode Universal Transformer outputs from encoder representation.

  It is similar to "transformer.decode", but it uses
  "universal_transformer_util.universal_transformer_decoder" instead of
  "transformer.transformer_decoder".

  Args:
    decoder_input: inputs to bottom of the model. [batch_size, decoder_length,
      hidden_dim]
    encoder_output: Encoder representation. [batch_size, input_length,
      hidden_dim]
    encoder_decoder_attention_bias: Bias and mask weights for encoder-decoder
      attention. [batch_size, input_length]
    decoder_self_attention_bias: Bias and mask weights for decoder
      self-attention. [batch_size, decoder_length]
    hparams: hyperparameters for model.
    cache: Unimplemented.
    decode_loop_step: Unused.
    nonpadding: optional Tensor with shape [batch_size, decoder_length]
    losses: Unused.

  Returns:
    Tuple of:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
      encoder_extra_output: which is extra encoder output used in some
        variants of the model (e.g. in ACT, to pass the ponder-time to body)
  """
  del decode_loop_step
  del losses
  # TODO(dehghani): enable caching.
  del cache

  # TF1-style dropout: the second argument is keep_prob, hence 1.0 - rate.
  decoder_input = tf.nn.dropout(decoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)

  # No caching in Universal Transformers!
  (decoder_output, dec_extra_output) = (
      universal_transformer_util.universal_transformer_decoder(
          decoder_input,
          encoder_output,
          decoder_self_attention_bias,
          encoder_decoder_attention_bias,
          hparams,
          nonpadding=nonpadding,
          save_weights_to=self.attention_weights))

  # Expand since t2t expects 4d tensors.
  return tf.expand_dims(decoder_output, axis=2), dec_extra_output
def decode_inputs_to_outputs(self, decoder_embed_inputs, encoder_outputs,
                             encoder_attn_bias, rule_id_input_placeholder,
                             mem_contexts, mem_outputs, global_step, score,
                             obj_tensors=None):
    """Run the decoder over embedded inputs, optionally fusing a rule memory.

    NOTE(review): the original file's indentation was lost in extraction; the
    block structure below was reconstructed from statement order — verify
    against project history before relying on branch boundaries.

    Args:
        decoder_embed_inputs: embedded decoder inputs; presumably
            [batch_size, decoder_len, dimension] — TODO confirm.
        encoder_outputs: encoder representation attended to by the decoder.
        encoder_attn_bias: attention bias/mask over encoder outputs.
        rule_id_input_placeholder: ids used to embed rule-memory entries.
        mem_contexts: rule-context memory (embedding source).
        mem_outputs: rule-output memory (embedding source).
        global_step: int64 scalar step; gates when memory output is used.
        score: score tensor mixed into the decoder embedding via
            update_decoder_embedding.
        obj_tensors: optional dict; supplies inputs in 'direct' mode
            ('direct_bert_output', 'direct_bert_bias', 'npad_w'), collects
            attention weights, and receives 'extra_decoder_loss' in ut2t mode.

    Returns:
        Tuple (final_output, decoder_output, context):
            final_output: decoder output, possibly fused with rule memory.
            decoder_output: raw decoder output before memory fusion.
            context: decoder context vector in 'rule' mode, else None.
    """
    if self.hparams.pos == 'timing':
        decoder_embed_inputs = common_attention.add_timing_signal_1d(
            decoder_embed_inputs)
        print('Use positional encoding in decoder text.')
    decoder_embed_inputs = self.update_decoder_embedding(
        decoder_embed_inputs, score, self.model_config.beam_search_size)
    # Causal (lower-triangular) mask: each position attends only to the left.
    decoder_attn_bias = common_attention.attention_bias_lower_triangle(
        tf.shape(decoder_embed_inputs)[1])
    # TF1-style dropout: the second argument is keep_prob.
    decoder_embed_inputs = tf.nn.dropout(
        decoder_embed_inputs,
        1.0 - self.hparams.layer_prepostprocess_dropout)
    if 'direct' in self.model_config.memory:
        # 'direct' mode: decode against both the encoder outputs and an
        # externally supplied BERT output/bias pair.
        assert 'direct_bert_output' in obj_tensors
        decoder_output = transformer.transformer_multi_decoder(
            decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
            encoder_attn_bias, obj_tensors['direct_bert_output'],
            obj_tensors['direct_bert_bias'], self.hparams,
            save_weights_to=obj_tensors,
            direct_mode=self.model_config.direct_mode)
        if self.model_config.npad_mode == 'static_seq':
            # Extra 1x1 convolution projection in static_seq npad mode.
            decoder_output = tf.nn.conv1d(
                decoder_output, obj_tensors['npad_w'], 1, 'SAME')
        return decoder_output, decoder_output, None
    elif 'rule' in self.model_config.memory:
        # 'rule' mode: attend from a decoder context vector into a memory of
        # rule contexts, and mix the retrieved rule outputs back in.
        decoder_output, contexts = transformer.transformer_decoder_contexts(
            decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
            encoder_attn_bias, self.hparams)
        # encoder_gate_w = tf.get_variable('encoder_gate_w', shape=(
        #     1, self.model_config.dimension, 1))
        # encoder_gate_b = tf.get_variable('encoder_gate_b', shape=(1, 1, 1))
        # encoder_gate = tf.tanh(encoder_gate_b + tf.nn.conv1d(encoder_outputs, encoder_gate_w, 1, 'SAME'))
        # encoder_context_outputs = tf.expand_dims(tf.reduce_mean(encoder_outputs * encoder_gate, axis=1), axis=1)
        cur_context = contexts[0]  # tf.concat(contexts, axis=-1)
        cur_mem_contexts = tf.stack(
            self.embedding_fn(rule_id_input_placeholder, mem_contexts),
            axis=1)
        cur_mem_outputs = tf.stack(
            self.embedding_fn(rule_id_input_placeholder, mem_outputs),
            axis=1)
        # Flatten (rule, sub-position) into one memory axis of length
        # max_target_rule_sublen * max_cand_rules.
        cur_mem_contexts = tf.reshape(
            cur_mem_contexts,
            [self.model_config.batch_size,
             self.model_config.max_target_rule_sublen *
             self.model_config.max_cand_rules,
             self.model_config.dimension])
        cur_mem_outputs = tf.reshape(
            cur_mem_outputs,
            [self.model_config.batch_size,
             self.model_config.max_target_rule_sublen *
             self.model_config.max_cand_rules,
             self.model_config.dimension])
        # bias = tf.expand_dims(
        #     -1e9 * tf.to_float(tf.equal(tf.stack(rule_id_input_placeholder, axis=1), 0)),
        #     axis=1)
        # weights = tf.nn.softmax(bias + tf.matmul(cur_context, cur_mem_contexts, transpose_b=True))
        # Soft attention over memory slots; NOTE(review): padded slots are
        # not masked here (masked variant is commented out above).
        weights = tf.nn.softmax(
            tf.matmul(cur_context, cur_mem_contexts, transpose_b=True))
        mem_output = tf.matmul(weights, cur_mem_outputs)
        # trainable_mem = 'stopgrad' not in self.model_config.rl_configs
        temp_output = tf.concat((decoder_output, mem_output), axis=-1)
        # w_u = tf.get_variable('w_ffn', shape=(
        #     1, self.model_config.dimension*2, self.model_config.dimension), trainable=trainable_mem)
        # b_u = tf.get_variable('b_ffn', shape=(
        #     1, 1, self.model_config.dimension), trainable=trainable_mem)
        # w_u.reuse_variables()
        # b_u.reuse_variables()
        # tf.get_variable_scope().reuse_variables()
        # Project the concatenated [decoder; memory] back to model dimension.
        w_t = tf.get_variable(
            'w_ffn',
            shape=(1, self.model_config.dimension * 2,
                   self.model_config.dimension),
            trainable=True)
        b_t = tf.get_variable(
            'b_ffn',
            shape=(1, 1, self.model_config.dimension),
            trainable=True)
        # w = tf.cond(tf.equal(tf.mod(self.global_step, 2), 0), lambda: w_t, lambda: w_u)
        # b = tf.cond(tf.equal(tf.mod(self.global_step, 2), 0), lambda: b_t, lambda: b_u)
        mem_output = tf.nn.conv1d(temp_output, w_t, 1, 'SAME') + b_t
        # Only use the memory-fused output after a warm-up period.
        g = tf.greater(
            global_step,
            tf.constant(self.model_config.memory_prepare_step,
                        dtype=tf.int64))
        final_output = tf.cond(g, lambda: mem_output, lambda: decoder_output)
        return final_output, decoder_output, cur_context
    else:
        if self.model_config.architecture == 'ut2t':
            (decoder_output, decoder_extra_output) = (
                universal_transformer_util.universal_transformer_decoder(
                    decoder_embed_inputs, encoder_outputs,
                    decoder_attn_bias, encoder_attn_bias, self.hparams,
                    save_weights_to=obj_tensors))
            # ACT ponder loss, only recorded during training.
            dec_ponder_times, dec_remainders = decoder_extra_output
            extra_dec_loss = (
                self.hparams.act_loss_weight *
                tf.reduce_mean(dec_ponder_times + dec_remainders))
            if self.is_train:
                obj_tensors['extra_decoder_loss'] = extra_dec_loss
        else:
            decoder_output = transformer.transformer_decoder(
                decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
                encoder_attn_bias, self.hparams,
                save_weights_to=obj_tensors,
                npad_mode=self.model_config.npad_mode)
        final_output = decoder_output
        return final_output, decoder_output, None
def context_encoder(self, contexts_emb, contexts, abbr_inp_emb=None):
    """Encode abbreviation contexts with the configured encoder variant.

    Args:
        contexts_emb: a tensor of [batch_size, max_context_len, emb_dim].
        contexts: a list of [max_context_len, batch_size] id tensors
            (stacked along axis=1 to build the padding mask).
        abbr_inp_emb: a tensor of [batch_size, context_len, emb_dim]; only
            used in 'abbr_ut2t' mode.

    Returns:
        Tuple (encoder_output, weights, extra_loss):
            encoder_output: [batch_size, context_len, channel_dim].
            weights: dict of multihead attention weights collected by the
                encoder, one entry per layer.
            extra_loss: ACT ponder loss for 'ut2t'/'abbr_ut2t' modes,
                None for 't2t'.

    Raises:
        ValueError: if model_config.encoder_mode is not one of
            't2t', 'ut2t', 'abbr_ut2t'.
    """
    weights = {}
    # Build a bias tensor as mask (big negative values for the padded part):
    # input [batch_size, context_len] -> output [batch_size, 1, 1, context_len].
    contexts_bias = common_attention.attention_bias_ignore_padding(
        tf.to_float(
            tf.equal(tf.stack(contexts, axis=1),
                     self.voc.encode(constant.PAD))))
    # TF1-style dropout: the second argument is keep_prob.
    contexts_emb = tf.nn.dropout(
        contexts_emb, 1.0 - self.hparams.layer_prepostprocess_dropout)
    if self.model_config.encoder_mode == 't2t':
        encoder_output = transformer.transformer_encoder(
            contexts_emb, contexts_bias, self.hparams,
            save_weights_to=weights)
        extra_loss = None
    elif self.model_config.encoder_mode == 'ut2t':
        encoder_output, extra_output = (
            universal_transformer_util.universal_transformer_encoder(
                contexts_emb, contexts_bias, self.hparams,
                save_weights_to=weights))
        # ACT ponder loss from the universal transformer encoder.
        enc_ponder_times, enc_remainders = extra_output
        extra_loss = (self.hparams.act_loss_weight *
                      tf.reduce_mean(enc_ponder_times + enc_remainders))
    elif self.model_config.encoder_mode == 'abbr_ut2t':
        encoder_output, extra_output = (
            universal_transformer_util.universal_transformer_encoder(
                contexts_emb, contexts_bias, self.hparams,
                save_weights_to=weights))
        enc_ponder_times, enc_remainders = extra_output
        extra_loss = (self.hparams.act_loss_weight *
                      tf.reduce_mean(enc_ponder_times + enc_remainders))
        # Run a UT decoder with the abbreviation embedding attending over the
        # encoded context. Only its ACT ponder loss is used; its output is
        # discarded. NOTE(review): discarding the decoder output looks
        # intentional (was `encoder_ouput2` in the original, never read) —
        # confirm with the model owner.
        decoder_output2, extra_output2 = (
            universal_transformer_util.universal_transformer_decoder(
                abbr_inp_emb, encoder_output,
                tf.zeros([self.model_config.batch_size, 1, 1, 1]),
                contexts_bias, self.hparams))
        del decoder_output2  # unused
        enc_ponder_times2, enc_remainders2 = extra_output2
        extra_loss2 = (self.hparams.act_loss_weight *
                       tf.reduce_mean(enc_ponder_times2 + enc_remainders2))
        extra_loss += extra_loss2
    else:
        # Fixed typo in the original message ('Unknow' -> 'Unknown').
        raise ValueError('Unknown encoder_mode.')
    return encoder_output, weights, extra_loss