def encode(self, inputs, target_space, hparams, features=None, losses=None,
           **kwargs):
  """Encode Universal Transformer inputs.

  It is similar to "transformer.encode", but it uses
  "universal_transformer_util.universal_transformer_encoder" instead of
  "transformer.transformer_encoder".

  Args:
    inputs: Transformer inputs [batch_size, input_length, input_height,
      hidden_dim] which will be flattened along the two spatial dimensions.
    target_space: scalar, target space ID.
    hparams: hyperparameters for model.
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.
    losses: Unused.
    **kwargs: additional arguments to pass to the encoder function.

  Returns:
    Tuple of:
      encoder_output: Encoder representation.
        [batch_size, input_length, hidden_dim]
      encoder_decoder_attention_bias: Bias and mask weights for
        encoder-decoder attention. [batch_size, input_length]
      encoder_extra_output: extra encoder output used in some variants of the
        model (e.g. in ACT, to pass the ponder-time to body).
  """
  del losses

  inputs = common_layers.flatten4d3d(inputs)

  encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
      transformer.transformer_prepare_encoder(
          inputs, target_space, hparams, features=features))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)

  (encoder_output, encoder_extra_output) = (
      universal_transformer_util.universal_transformer_encoder(
          encoder_input,
          self_attention_bias,
          hparams,
          nonpadding=transformer.features_to_nonpadding(features, "inputs"),
          save_weights_to=self.attention_weights))

  return encoder_output, encoder_decoder_attention_bias, encoder_extra_output
def symbols_to_logits_fn(ids, i, cache):
  """Go from ids to logits for next symbol."""
  ids = ids[:, -1:]
  targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
  targets = preprocess_targets(targets, i)

  bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

  with tf.variable_scope("body"):
    body_outputs = dp(
        self.decode,
        targets,
        cache.get("encoder_output"),
        cache.get("encoder_decoder_attention_bias"),
        bias,
        hparams,
        cache,
        nonpadding=features_to_nonpadding(features, "targets"))

  update_decoder_attention_history(cache)
  cache["body_outputs"] = tf.concat([cache["body_outputs"], body_outputs[0]],
                                    axis=2)

  modality_name = hparams.name.get(
      "targets",
      modalities.get_name(target_modality))(hparams, target_vocab_size)
  with tf.variable_scope(modality_name):
    top = hparams.top.get("targets", modalities.get_top(target_modality))
    logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]

  ret = tf.squeeze(logits, axis=[1, 2, 3])
  if partial_targets is not None:
    # If the position is within the given partial targets, we alter the
    # logits to always return those values.
    # A faster approach would be to process the partial targets in one
    # iteration in order to fill the corresponding parts of the cache.
    # This would require broader changes, though.
    vocab_size = tf.shape(ret)[1]

    def forced_logits():
      return tf.one_hot(
          tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0, -1e9)

    ret = tf.cond(
        tf.less(i, partial_targets_length), forced_logits, lambda: ret)
  return ret, cache
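# The forced-logits trick above can be shown in isolation: a one-hot vector
# with logit 0.0 at the forced token id and -1e9 everywhere else makes any
# argmax/softmax-based search emit the forced token. A minimal numpy sketch
# (shapes and names are illustrative, not from the original code):
import numpy as np

def forced_logits_demo(forced_ids, vocab_size, on_value=0.0, off_value=-1e9):
  """Mirrors tf.one_hot(ids, vocab_size, 0.0, -1e9) from the snippet above."""
  logits = np.full((len(forced_ids), vocab_size), off_value, dtype=np.float32)
  logits[np.arange(len(forced_ids)), forced_ids] = on_value
  return logits

_logits = forced_logits_demo(forced_ids=[3, 0], vocab_size=5)
assert (_logits.argmax(axis=1) == np.array([3, 0])).all()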
def encode(self, inputs, target_space, hparams, features=None, losses=None):
  """Encode Universal Transformer inputs.

  It is similar to "transformer.encode", but it uses
  "universal_transformer_util.universal_transformer_encoder" instead of
  "transformer.transformer_encoder".

  Args:
    inputs: Transformer inputs [batch_size, input_length, input_height,
      hidden_dim] which will be flattened along the two spatial dimensions.
    target_space: scalar, target space ID.
    hparams: hyperparameters for model.
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.
    losses: Unused.

  Returns:
    Tuple of:
      encoder_output: Encoder representation.
        [batch_size, input_length, hidden_dim]
      encoder_decoder_attention_bias: Bias and mask weights for
        encoder-decoder attention. [batch_size, input_length]
      encoder_extra_output: extra encoder output used in some variants of the
        model (e.g. in ACT, to pass the ponder-time to body).
  """
  del losses

  inputs = common_layers.flatten4d3d(inputs)

  encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
      transformer.transformer_prepare_encoder(
          inputs, target_space, hparams, features=features))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)

  (encoder_output, encoder_extra_output) = (
      universal_transformer_util.universal_transformer_encoder(
          encoder_input,
          self_attention_bias,
          hparams,
          nonpadding=transformer.features_to_nonpadding(features, "inputs"),
          save_weights_to=self.attention_weights))

  return encoder_output, encoder_decoder_attention_bias, encoder_extra_output
def symbols_to_logits_fn(ids, i, cache):
  """Go from ids to logits for next symbol."""
  ids = ids[:, -1:]
  targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
  targets = preprocess_targets(targets, i)

  bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

  with tf.variable_scope("body"):
    body_outputs = dp(
        self.decode,
        targets,
        cache.get("encoder_output"),
        cache.get("encoder_decoder_attention_bias"),
        bias,
        hparams,
        cache,
        nonpadding=features_to_nonpadding(features, "targets"))

  with tf.variable_scope(target_modality.name):
    logits = target_modality.top_sharded(body_outputs, None, dp)[0]

  ret = tf.squeeze(logits, axis=[1, 2, 3])
  return ret, cache
def encode(self, features, input_key):
  hparams = self._hparams
  inputs = common_layers.flatten4d3d(features[input_key])

  (encoder_input, encoder_self_attention_bias, _) = (
      transformer.transformer_prepare_encoder(inputs, problem.SpaceID.EN_TOK,
                                              hparams))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)
  encoder_output = transformer.transformer_encoder(
      encoder_input,
      encoder_self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, input_key))
  # Mean-pool over the time axis to get a single vector per example.
  encoder_output = tf.reduce_mean(encoder_output, axis=1)

  return encoder_output
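# One caveat worth illustrating: tf.reduce_mean over axis 1 averages padding
# positions too. A minimal numpy sketch of the difference between the naive
# mean used above and a nonpadding-weighted mean (toy values, illustrative
# names only):
import numpy as np

# Toy encoder output: batch=1, length=4, hidden=2; last two steps are padding.
_enc = np.array([[[1., 1.], [3., 3.], [0., 0.], [0., 0.]]])
_nonpadding = np.array([[1., 1., 0., 0.]])  # 1.0 for real tokens

_naive = _enc.mean(axis=1)                                # [[1., 1.]]
_masked = ((_enc * _nonpadding[..., None]).sum(axis=1)
           / _nonpadding.sum(axis=1, keepdims=True))      # [[2., 2.]]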
def encode(self, inputs, target_space, hparams, features=None, losses=None):
  """Encode inputs using _encoder().

  This performs the same way as transformer.Transformer.encode with the
  encoder portion replaced with _encoder().

  Args:
    inputs: Input [batch_size, input_length, input_height, hidden_dim] tensor
      which will be flattened along the two spatial dimensions.
    target_space: scalar, target space ID.
    hparams: Hyperparameters for model.
    features: Optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.
    losses: Unused list of losses.

  Returns:
    Tuple of:
      encoder_output: Encoder representation.
        [batch_size, input_length, hidden_dim]
      encoder_decoder_attention_bias: Bias and mask weights for
        encoder-decoder attention. [batch_size, input_length]

  Raises:
    ValueError: If encoder type not found.
  """
  inputs = common_layers.flatten4d3d(inputs)

  encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
      transformer.transformer_prepare_encoder(
          inputs, target_space, hparams, features=features))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)

  encoder_output = self._encoder(
      encoder_input,
      self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, "inputs"),
      save_weights_to=self.attention_weights)

  return encoder_output, encoder_decoder_attention_bias
def sim_encode(inputs, target_space, hparams, features):
  inputs = common_layers.flatten4d3d(inputs)

  (encoder_input, encoder_self_attention_bias, _) = (
      transformer.transformer_prepare_encoder(inputs, target_space, hparams))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)
  encoder_output = transformer.transformer_encoder(
      encoder_input,
      encoder_self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, "inputs"))

  # Mean-pool over positions, then L2-normalize to obtain a unit-length
  # sentence embedding of shape [batch_size, hidden_size].
  positional_mean = tf.nn.l2_normalize(tf.reduce_mean(encoder_output, 1), 1)

  return positional_mean
def universal_transformer_encoder(inputs,
                                  target_space,
                                  hparams,
                                  features=None,
                                  make_image_summary=False):
  encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
      transformer.transformer_prepare_encoder(
          inputs, target_space, hparams, features=features))
  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)
  # Only the encoder output is used; the extra output (e.g. ACT ponder
  # statistics) and the encoder-decoder bias are discarded here.
  [encoder_output, _] = (
      universal_transformer_util.universal_transformer_encoder(
          encoder_input,
          self_attention_bias,
          hparams,
          nonpadding=transformer.features_to_nonpadding(features, "inputs"),
          save_weights_to=None,
          make_image_summary=make_image_summary))
  return encoder_output
def body(self, features):
  hparams = self._hparams
  inputs = features["inputs"]
  target_space = features["target_space_id"]

  inputs = common_layers.flatten4d3d(inputs)

  (encoder_input, encoder_self_attention_bias, _) = (
      transformer.transformer_prepare_encoder(inputs, target_space, hparams))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)
  encoder_output = transformer.transformer_encoder(
      encoder_input,
      encoder_self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, "inputs"))

  # Keep only the first position (CLS-style pooling) and restore the channel
  # dimension expected downstream.
  encoder_output = encoder_output[:, :1, :]
  encoder_output = tf.expand_dims(encoder_output, 2)

  return encoder_output
def encode(self, stories, questions, target_space, hparams, features=None):
  """Encode transformer inputs.

  Args:
    stories: Transformer story inputs.
    questions: Transformer question inputs, concatenated with the stories
      along the length dimension.
    target_space: scalar, target space ID.
    hparams: hyperparameters for model.
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.

  Returns:
    Tuple of:
      encoder_output: Encoder representation.
        [batch_size, input_length, hidden_dim]
      encoder_decoder_attention_bias: Bias and mask weights for
        encoder-decoder attention. [batch_size, input_length]
      encoder_extra_output: extra encoder output used in some variants of the
        model (e.g. in ACT, to pass the ponder-time to body).
  """
  inputs = tf.concat([stories, questions], axis=1)
  # inputs = common_layers.flatten4d3d(inputs)

  encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
      transformer.transformer_prepare_encoder(
          inputs, target_space, hparams, features=features))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)

  (encoder_output, extra_output) = (
      universal_transformer_util.universal_transformer_encoder(
          encoder_input,
          self_attention_bias,
          hparams,
          nonpadding=transformer.features_to_nonpadding(features, "inputs"),
          save_weights_to=self.attention_weights))

  return encoder_output, encoder_decoder_attention_bias, extra_output
def encode(self, inputs, target_space, hparams, features=None):
  """Encode transformer inputs.

  Args:
    inputs: Transformer inputs [batch_size, input_length, input_height,
      hidden_dim] which will be flattened along the two spatial dimensions.
    target_space: scalar, target space ID.
    hparams: hyperparameters for model.
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.

  Returns:
    Tuple of:
      encoder_output: Encoder representation.
        [batch_size, input_length, hidden_dim]
      encoder_extra_output: extra encoder output used in some variants of the
        model (e.g. in ACT, to pass the ponder-time to body).
  """
  inputs = common_layers.flatten4d3d(inputs)

  (encoder_input, self_attention_bias, _) = (
      transformer.transformer_prepare_encoder(inputs, target_space, hparams))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)

  (encoder_output,
   encoder_extra_output) = r_transformer_util.r_transformer_encoder(
       encoder_input,
       self_attention_bias,
       hparams,
       nonpadding=transformer.features_to_nonpadding(features, "inputs"),
       save_weights_to=self.attention_weights)

  return encoder_output, encoder_extra_output
def body(self, features):
  """Transformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs.
            [batch_size, input_length, 1, hidden_dim].
        "targets": Target decoder outputs.
            [batch_size, decoder_length, 1, hidden_dim]
        "target_space_id": A scalar int from data_generators.problem.SpaceID.

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
  """
  hparams = self._hparams

  losses = []

  if self.has_input:
    # Use melody-only as input features.
    inputs = features["melody"]
    target_space = features["target_space_id"]
    encoder_output, encoder_decoder_attention_bias = self.encode(
        inputs, target_space, hparams, features=features, losses=losses)
  else:
    encoder_output, encoder_decoder_attention_bias = (None, None)

  targets = features["targets"]
  targets_shape = common_layers.shape_list(targets)
  targets = common_layers.flatten4d3d(targets)
  decoder_input, decoder_self_attention_bias = self._prepare_decoder_fn(
      targets, hparams, features=features)

  # Not all subclasses of Transformer support keyword arguments related to
  # recurrent memory, so only pass these arguments if memory is enabled.
  decode_kwargs = {}
  if self.recurrent_memory_by_layer is not None:
    # TODO(kitaev): The chunk_number feature currently has the same shape as
    # "targets", but this is only for the purposes of sharing sharding code.
    # In fact every token within an example must have the same chunk number.
    chunk_number_each_token = tf.squeeze(features["chunk_number"], (-1, -2))
    chunk_number_each_example = chunk_number_each_token[:, 0]
    # Uncomment the code below to verify that tokens within a batch share the
    # same chunk number:
    # with tf.control_dependencies([
    #     tf.assert_equal(chunk_number_each_token,
    #                     chunk_number_each_example[:, None])
    # ]):
    #   chunk_number_each_example = tf.identity(chunk_number_each_example)
    decode_kwargs = dict(
        recurrent_memory_by_layer=self.recurrent_memory_by_layer,
        chunk_number=chunk_number_each_example,
    )
  decoder_output = self.decode(
      decoder_input,
      encoder_output,
      encoder_decoder_attention_bias,
      decoder_self_attention_bias,
      hparams,
      nonpadding=features_to_nonpadding(features, "targets"),
      losses=losses,
      **decode_kwargs)
  expected_attentions = features.get("expected_attentions")
  if expected_attentions is not None:
    attention_loss = common_attention.encoder_decoder_attention_loss(
        expected_attentions, self.attention_weights,
        hparams.expected_attention_loss_type,
        hparams.expected_attention_loss_multiplier)
    return decoder_output, {"attention_loss": attention_loss}

  ret = tf.reshape(decoder_output, targets_shape)
  if losses:
    return ret, {"extra_loss": tf.add_n(losses)}
  else:
    return ret
def perf_transformer_encode(encoder_function,
                            inputs,
                            target_space,
                            hparams,
                            baseline,
                            attention_weights=None,
                            features=None,
                            losses=None,
                            prepare_encoder_fn=None,
                            **kwargs):
  """Encoding for performance autoencoder, which mean-aggregates across time.

  Args:
    encoder_function: the encoder function
    inputs: Transformer inputs [batch_size, input_length, 1, hidden_dim]
      which will be flattened along the two spatial dimensions.
    target_space: scalar, target space ID.
    hparams: hyperparameters for model.
    baseline: if True, does not mean-aggregate the encoder output.
    attention_weights: weight to store attention to.
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.
    losses: optional list onto which to append extra training losses
    prepare_encoder_fn: optional, alternative to transformer_prepare_encoder.
    **kwargs: additional arguments to pass to encoder_function

  Returns:
    Tuple of:
        encoder_output: Encoder representation.
            [batch_size, input_length, hidden_dim]
        encoder_decoder_attention_bias: Bias and mask weights for
            encoder-decoder attention. [batch_size, input_length]
  """
  inputs = common_layers.flatten4d3d(inputs)

  if not prepare_encoder_fn:
    prepare_encoder_fn = transformer_prepare_encoder
  encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
      prepare_encoder_fn(
          inputs,
          target_space,
          hparams,
          features=features,
          reuse_target_embedding=tf.AUTO_REUSE))

  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
      value=hparams.layer_prepostprocess_dropout,
      hparams=hparams)

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)

  attn_bias_for_padding = None
  # Otherwise the encoder will just use encoder_self_attention_bias.
  if hparams.unidirectional_encoder:
    attn_bias_for_padding = encoder_decoder_attention_bias

  encoder_output = encoder_function(
      encoder_input,
      self_attention_bias,
      hparams,
      name="encoder",
      nonpadding=features_to_nonpadding(features, "inputs"),
      save_weights_to=attention_weights,
      make_image_summary=not common_layers.is_xla_compiled(),
      losses=losses,
      attn_bias_for_padding=attn_bias_for_padding,
      **kwargs)

  if not baseline:
    encoder_output = tf.math.reduce_mean(
        encoder_output, axis=1, keep_dims=True)
    encoder_decoder_attention_bias = tf.math.reduce_mean(
        encoder_decoder_attention_bias, axis=-1, keep_dims=True)

  return encoder_output, encoder_decoder_attention_bias
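# The mean-aggregation step above compresses the whole sequence into a single
# time step: encoder output goes from [batch, length, hidden] to
# [batch, 1, hidden], and the encoder-decoder bias from [batch, 1, 1, length]
# to [batch, 1, 1, 1]. A minimal numpy sketch of those shape changes (toy
# shapes, illustrative only):
import numpy as np

_batch, _length, _hidden = 2, 6, 4
_encoder_output = np.random.randn(_batch, _length, _hidden)
_enc_dec_bias = np.zeros((_batch, 1, 1, _length))  # 0.0 = attend, -1e9 = mask

_pooled_output = _encoder_output.mean(axis=1, keepdims=True)
_pooled_bias = _enc_dec_bias.mean(axis=-1, keepdims=True)

assert _pooled_output.shape == (_batch, 1, _hidden)
assert _pooled_bias.shape == (_batch, 1, 1, 1)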
def mel_perf_transformer_encode(encoder_function,
                                perf_inputs,
                                mel_inputs,
                                target_space,
                                hparams,
                                attention_weights=None,
                                features=None,
                                losses=None,
                                prepare_encoder_fn=None,
                                **kwargs):
  """Encode transformer inputs. Used for melody & performance autoencoder.

  Performance is mean-aggregated across time and combined with melody in a
  variety of different ways.

  Args:
    encoder_function: the encoder function
    perf_inputs: Transformer inputs [batch_size, input_length, 1, hidden_dim]
      which will be flattened along the two spatial dimensions.
    mel_inputs: Transformer inputs [batch_size, input_length, 1, hidden_dim]
      which will be flattened along the two spatial dimensions.
    target_space: scalar, target space ID.
    hparams: hyperparameters for model.
    attention_weights: weight to store attention to.
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.
    losses: optional list onto which to append extra training losses
    prepare_encoder_fn: optional, alternative to transformer_prepare_encoder.
    **kwargs: additional arguments to pass to encoder_function

  Returns:
    Tuple of:
        encoder_output: Encoder representation.
            [batch_size, input_length, hidden_dim]
        encoder_decoder_attention_bias: Bias and mask weights for
            encoder-decoder attention. [batch_size, input_length]
  """
  perf_inputs = common_layers.flatten4d3d(perf_inputs)
  mel_inputs = common_layers.flatten4d3d(mel_inputs)

  if not prepare_encoder_fn:
    prepare_encoder_fn = transformer_prepare_encoder
  perf_encoder_input, perf_self_attention_bias, perf_encdec_attention_bias = (
      prepare_encoder_fn(
          perf_inputs,
          target_space,
          hparams,
          features=features,
          reuse_target_embedding=tf.AUTO_REUSE))

  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
      value=hparams.layer_prepostprocess_dropout,
      hparams=hparams)

  perf_encoder_input = tf.nn.dropout(
      perf_encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

  perf_attn_bias_for_padding = None
  # Otherwise the encoder will just use encoder_self_attention_bias.
  if hparams.unidirectional_encoder:
    perf_attn_bias_for_padding = perf_encdec_attention_bias

  # Do the same thing for melody.
  mel_encoder_input, mel_self_attention_bias, mel_encdec_attention_bias = (
      prepare_encoder_fn(
          mel_inputs,
          target_space,
          hparams,
          features=features,
          reuse_target_embedding=tf.AUTO_REUSE))

  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
      value=hparams.layer_prepostprocess_dropout,
      hparams=hparams)

  mel_encoder_input = tf.nn.dropout(
      mel_encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

  mel_attn_bias_for_padding = None
  # Otherwise the encoder will just use encoder_self_attention_bias.
  if hparams.unidirectional_encoder:
    mel_attn_bias_for_padding = mel_encdec_attention_bias

  # Use the proper encoder function for perf/melody.
  perf_encoder_output = encoder_function(
      perf_encoder_input,
      perf_self_attention_bias,
      hparams,
      name="perf_encoder",
      nonpadding=features_to_nonpadding(features, "inputs"),
      save_weights_to=attention_weights,
      make_image_summary=not common_layers.is_xla_compiled(),
      losses=losses,
      attn_bias_for_padding=perf_attn_bias_for_padding,
      **kwargs)
  # Same thing for melody.
  mel_encoder_output = encoder_function(
      mel_encoder_input,
      mel_self_attention_bias,
      hparams,
      name="mel_encoder",
      nonpadding=features_to_nonpadding(features, "inputs"),
      save_weights_to=attention_weights,
      make_image_summary=not common_layers.is_xla_compiled(),
      losses=losses,
      attn_bias_for_padding=mel_attn_bias_for_padding,
      **kwargs)

  # Compute the global mean vector of the performance encoding, to be
  # combined with the full melody encoding below.
  perf_mean_vector = tf.math.reduce_mean(
      perf_encoder_output, axis=1, keep_dims=True)

  # Different methods of aggregating over the performance + melody vectors!
  if hparams.aggregation == "sum":
    # Add both mean performance and melody vectors together.
    perf_mean_bias = tf.math.reduce_mean(
        perf_encdec_attention_bias, axis=-1, keep_dims=True)
    encoder_output = mel_encoder_output + perf_mean_vector
    encoder_decoder_attention_bias = (
        mel_encdec_attention_bias + perf_mean_bias)
  elif hparams.aggregation == "concat":
    # Concatenate melody with the mean-aggregated performance embedding.
    # NOTE: the stop-token depth (384) is hard-coded; it should match
    # hparams.hidden_size.
    stop_token = tf.zeros((1, 1, 384))
    encoder_output = tf.concat(
        [mel_encoder_output, stop_token, perf_mean_vector], axis=1)
    perf_mean_bias = tf.math.reduce_mean(
        perf_encdec_attention_bias, axis=-1, keep_dims=True)
    stop_bias = tf.zeros((1, 1, 1, 1))
    encoder_decoder_attention_bias = tf.concat(
        [mel_encdec_attention_bias, stop_bias, perf_mean_bias], axis=-1)
  elif hparams.aggregation == "tile":
    # Tile the performance embedding across each position of the melody
    # embedding and concatenate along the depth dimension.
    dynamic_val = tf.shape(mel_encoder_output)[1]
    shp = tf.convert_to_tensor([1, dynamic_val, 1], dtype=tf.int32)
    tiled_mean = tf.tile(perf_mean_vector, shp)
    encoder_output = tf.concat([mel_encoder_output, tiled_mean], axis=-1)
    encoder_decoder_attention_bias = mel_encdec_attention_bias
  else:
    raise NotImplementedError(
        "aggregation method must be in [sum, concat, tile].")

  return encoder_output, encoder_decoder_attention_bias
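# The three aggregation branches above differ only in how the pooled
# performance vector [batch, 1, hidden] is merged with the melody encoding
# [batch, length, hidden]. A toy numpy sketch of the resulting shapes
# (illustrative, not the original code):
import numpy as np

_batch, _length, _hidden = 2, 5, 4
_mel = np.random.randn(_batch, _length, _hidden)
_perf_mean = np.random.randn(_batch, 1, _hidden)

# "sum": broadcast-add the pooled vector to every melody position.
_summed = _mel + _perf_mean                                   # (2, 5, 4)

# "concat": append the pooled vector as one extra time step.
_concatenated = np.concatenate([_mel, _perf_mean], axis=1)    # (2, 6, 4)

# "tile": repeat the pooled vector along time, widening the depth.
_tiled = np.concatenate(
    [_mel, np.tile(_perf_mean, (1, _length, 1))], axis=-1)    # (2, 5, 8)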
def hierarchical_attention_network_encoder(
    encoder_input,
    encoder_self_attention_bias,
    contexts,
    context_self_attention_biases,
    features,
    hparams,
    name="hierarchical_attention_network_encoder",
    save_weights_to=None,
    make_image_summary=True,
    losses=None):
  input_x = encoder_input
  context_xs = {}
  for context_name in contexts:
    context_xs[context_name] = contexts[context_name]
  context_paddings = {}
  context_nonpaddings = {}
  context_pad_removers = {}

  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    input_padding = common_attention.attention_bias_to_padding(
        encoder_self_attention_bias)
    input_nonpadding = 1.0 - input_padding
    for context_name in context_self_attention_biases:
      context_paddings[context_name] = (
          common_attention.attention_bias_to_padding(
              context_self_attention_biases[context_name]))
      context_nonpaddings[context_name] = 1.0 - context_paddings[context_name]

    input_pad_remover = None
    for context_name in context_paddings:
      context_pad_removers[context_name] = None
    if hparams.use_pad_remover and not common_layers.is_xla_compiled():
      input_pad_remover = expert_utils.PadRemover(input_padding)
      for context_name in context_paddings:
        context_pad_removers[context_name] = expert_utils.PadRemover(
            context_paddings[context_name])

    # Copy hparams, except num_hidden_layers -> num_hidden_layers - 1.
    temp_hparam = tf.contrib.training.HParams()
    for key, val in hparams.values().items():
      temp_hparam.add_hparam(key, val)
    temp_hparam.set_hparam("num_hidden_layers", hparams.num_hidden_layers - 1)
    encoder_output = transformer_with_contexts_layers.transformer_encoder(
        input_x,
        encoder_self_attention_bias,
        temp_hparam,
        nonpadding=features_to_nonpadding(features, "inputs"),
        save_weights_to=save_weights_to,
        make_image_summary=make_image_summary)

    context_encoded_outputs = {}
    for context_name in context_xs:
      context_encoded_outputs[context_name] = (
          transformer_with_contexts_layers.transformer_encoder(
              context_xs[context_name],
              context_self_attention_biases[context_name],
              hparams,
              nonpadding=features_to_nonpadding(features, context_name),
              save_weights_to=save_weights_to,
              make_image_summary=make_image_summary))

    with tf.variable_scope("word_abstraction", reuse=tf.AUTO_REUSE):
      encoder_word_level_query = common_layers.dense(
          encoder_output, hparams.hidden_size)  # q_w = f_w(h_t)
      encoder_word_level_abstraction = {}
      for context_name in context_encoded_outputs:
        encoder_word_level_abstraction[context_name] = (
            transformer_with_contexts_layers.multihead_attention(
                common_layers.layer_preprocess(encoder_word_level_query,
                                               hparams),
                context_encoded_outputs[context_name],
                context_self_attention_biases[context_name],
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                attention_type=hparams.self_attention_type,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                max_relative_position=hparams.max_relative_position,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d")))  # s^j

      sentence_information = tf.concat([
          encoder_word_level_abstraction[context_name]
          for context_name in encoder_word_level_abstraction
      ], axis=1)

    with tf.variable_scope("sentence_abstraction", reuse=tf.AUTO_REUSE):
      encoder_sentence_level_query = common_layers.dense(
          encoder_output, hparams.hidden_size)  # q_s = f_s(h_t)
      context_padding = common_attention.embedding_to_padding(
          sentence_information)
      ignore_padding = common_attention.attention_bias_ignore_padding(
          context_padding)
      contextual_information = (
          transformer_with_contexts_layers.multihead_attention(
              common_layers.layer_preprocess(encoder_sentence_level_query,
                                             hparams),
              sentence_information,
              ignore_padding,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              save_weights_to=save_weights_to,
              make_image_summary=make_image_summary,
              max_relative_position=hparams.max_relative_position,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              vars_3d=hparams.get("attention_variables_3d")
          ))  # MultiHead(q_s, s^j), [batch, encoder_length, hidden_dim]

      contextual_information = common_layers.dense_relu_dense(
          contextual_information, hparams.filter_size, hparams.hidden_size)

    with tf.variable_scope("context_gating", reuse=tf.AUTO_REUSE):
      gate_lambda = tf.nn.sigmoid(
          common_layers.dense(contextual_information, hparams.hidden_size) +
          common_layers.dense(encoder_output, hparams.hidden_size))
      encoder_output = (
          gate_lambda * encoder_output +
          (1 - gate_lambda) * contextual_information)

  return common_layers.layer_preprocess(encoder_output, hparams)
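# The context gate above computes lambda = sigmoid(dense(c) + dense(h)) and
# mixes the two streams as lambda * h + (1 - lambda) * c. A minimal numpy
# sketch of that arithmetic (random matrices stand in for the learned dense
# layers; names here are illustrative, not from the code above):
import numpy as np

def _sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

_hidden = 4
_h = np.random.randn(2, 3, _hidden)      # encoder_output
_c = np.random.randn(2, 3, _hidden)      # contextual_information
_w_c = np.random.randn(_hidden, _hidden)  # stand-in for dense on c
_w_h = np.random.randn(_hidden, _hidden)  # stand-in for dense on h

_gate = _sigmoid(_c @ _w_c + _h @ _w_h)  # lambda in [0, 1], per dimension
_gated = _gate * _h + (1.0 - _gate) * _c  # gate==1 keeps h, gate==0 keeps c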
def hierarchical_context_encoder(encoder_input,
                                 encoder_self_attention_bias,
                                 contexts,
                                 context_self_attention_biases,
                                 features,
                                 hparams,
                                 name="discourse_aware_encoder",
                                 save_weights_to=None,
                                 make_image_summary=True,
                                 losses=None):
  input_x = encoder_input
  context_xs = {}
  for context_name in contexts:
    context_xs[context_name] = contexts[context_name]
  context_paddings = {}
  context_nonpaddings = {}
  context_pad_removers = {}

  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    input_padding = common_attention.attention_bias_to_padding(
        encoder_self_attention_bias)
    input_nonpadding = 1.0 - input_padding
    for context_name in context_self_attention_biases:
      context_paddings[context_name] = (
          common_attention.attention_bias_to_padding(
              context_self_attention_biases[context_name]))
      context_nonpaddings[context_name] = 1.0 - context_paddings[context_name]

    input_pad_remover = None
    for context_name in context_paddings:
      context_pad_removers[context_name] = None
    if hparams.use_pad_remover and not common_layers.is_xla_compiled():
      input_pad_remover = expert_utils.PadRemover(input_padding)
      for context_name in context_paddings:
        context_pad_removers[context_name] = expert_utils.PadRemover(
            context_paddings[context_name])

    # Copy hparams, except num_hidden_layers -> num_hidden_layers - 1.
    temp_hparam = tf.contrib.training.HParams()
    for key, val in hparams.values().items():
      temp_hparam.add_hparam(key, val)
    temp_hparam.set_hparam("num_hidden_layers", hparams.num_hidden_layers - 1)
    encoder_output = transformer_with_contexts_layers.transformer_encoder(
        input_x,
        encoder_self_attention_bias,
        temp_hparam,
        nonpadding=features_to_nonpadding(features, "inputs"),
        save_weights_to=save_weights_to,
        make_image_summary=make_image_summary)

    context_encoded_outputs = {}
    for context_name in context_xs:
      context_encoded_outputs[context_name] = (
          transformer_with_contexts_layers.transformer_encoder(
              context_xs[context_name],
              context_self_attention_biases[context_name],
              temp_hparam,
              nonpadding=features_to_nonpadding(features, context_name),
              save_weights_to=save_weights_to,
              make_image_summary=make_image_summary))

    with tf.variable_scope("hierarchical_context_encoder",
                           reuse=tf.AUTO_REUSE):
      for context_name in context_encoded_outputs:
        # Self-attention feed-forward.
        _y = ffn_self_attention_layer(
            context_encoded_outputs[context_name],
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            save_weights_to=save_weights_to,
            name="attentive_sum")
        # Mean over sequence length.
        context_encoded_outputs[context_name] = tf.reduce_mean(
            _y, axis=1, keep_dims=True)

      encoded_contexts = [
          context_encoded_outputs[context_name]
          for context_name in context_encoded_outputs
      ]
      encoded_contexts = tf.concat(encoded_contexts, axis=1)

      # Copy hparams, except num_hidden_layers -> 1.
      temp_hparam = tf.contrib.training.HParams()
      for key, val in hparams.values().items():
        temp_hparam.add_hparam(key, val)
      temp_hparam.set_hparam("num_hidden_layers", 1)
      context_padding = common_attention.embedding_to_padding(
          encoded_contexts)
      ignore_padding = common_attention.attention_bias_ignore_padding(
          context_padding)
      encoded_contexts = transformer_encoder(encoded_contexts, ignore_padding,
                                             temp_hparam)

    with tf.variable_scope("encoder/layer_%d" % hparams.num_hidden_layers,
                           reuse=tf.AUTO_REUSE):
      with tf.variable_scope("context_input_attention"):
        context_padding = common_attention.embedding_to_padding(
            encoded_contexts)
        ignore_padding = common_attention.attention_bias_ignore_padding(
            context_padding)
        _y = common_attention.multihead_attention(
            common_layers.layer_preprocess(encoder_output, hparams),
            encoded_contexts,
            ignore_padding,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            attention_type=hparams.self_attention_type,
            save_weights_to=save_weights_to,
            make_image_summary=make_image_summary,
            max_relative_position=hparams.max_relative_position,
            dropout_broadcast_dims=attention_dropout_broadcast_dims,
            max_length=hparams.get("max_length"),
            vars_3d=hparams.get("attention_variables_3d"))
        encoded_contexts = common_layers.layer_postprocess(
            encoder_output, _y, hparams)

      with tf.variable_scope("input_self_attention"):
        _y = common_attention.multihead_attention(
            common_layers.layer_preprocess(encoder_output, hparams),
            None,
            encoder_self_attention_bias,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            attention_type=hparams.self_attention_type,
            save_weights_to=save_weights_to,
            max_relative_position=hparams.max_relative_position,
            make_image_summary=make_image_summary,
            dropout_broadcast_dims=attention_dropout_broadcast_dims,
            max_length=hparams.get("max_length"),
            vars_3d=hparams.get("attention_variables_3d"))
        encoder_output = common_layers.layer_postprocess(
            encoder_output, _y, hparams)

      with tf.variable_scope("gated_sum"):
        _depth = common_layers.shape_list(encoder_output)[-1]
        gate = tf.layers.dense(
            tf.concat([encoded_contexts, encoder_output], axis=-1),
            _depth,
            activation=tf.nn.sigmoid)
        if save_weights_to:
          save_weights_to["gated_sum"] = gate
        encoder_output = (
            gate * encoder_output + (1. - gate) * encoded_contexts)

      with tf.variable_scope("ffn"):
        _y = transformer_ffn_layer(
            common_layers.layer_preprocess(encoder_output, hparams),
            hparams,
            input_pad_remover,
            conv_padding="SAME",
            nonpadding_mask=input_nonpadding,
            losses=losses)
        encoder_output = common_layers.layer_postprocess(
            encoder_output, _y, hparams)

  return common_layers.layer_preprocess(encoder_output, hparams)
def body(self, features):
  """Transformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs.
            [batch_size, input_length, 1, hidden_dim].
        "targets": Target decoder outputs.
            [batch_size, decoder_length, 1, hidden_dim]
        "target_space_id": A scalar int from data_generators.problem.SpaceID.

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
  """
  hparams = self._hparams

  losses = []

  if self.has_input:
    inputs = features["inputs"]
    target_space = features["target_space_id"]
    encoder_output, encoder_decoder_attention_bias = self.encode(
        inputs, target_space, hparams, features=features, losses=losses)
  else:
    encoder_output, encoder_decoder_attention_bias = (None, None)

  # Here we replace the "original" encoder output with BERT's.
  encoder_output = self.bert.get_sequence_output(
  )  # [batch_size, seq_length, hidden_size]

  targets = features["targets"]
  targets_shape = common_layers.shape_list(targets)
  targets = common_layers.flatten4d3d(targets)
  decoder_input, decoder_self_attention_bias = (
      transformer.transformer_prepare_decoder(
          targets, hparams, features=features))
  decoder_output = self.decode(
      decoder_input,
      encoder_output,
      encoder_decoder_attention_bias,
      decoder_self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, "targets"),
      losses=losses)

  expected_attentions = features.get("expected_attentions")
  if expected_attentions is not None:
    attention_loss = common_attention.encoder_decoder_attention_loss(
        expected_attentions, self.attention_weights,
        hparams.expected_attention_loss_type,
        hparams.expected_attention_loss_multiplier)
    return decoder_output, {"attention_loss": attention_loss}

  ret = tf.reshape(decoder_output, targets_shape)
  if losses:
    return ret, {"extra_loss": tf.add_n(losses)}
  else:
    return ret
def encode(self, inputs, target_space, hparams, features=None, losses=None,
           **kwargs):
  """Encode Universal Transformer inputs.

  It is similar to "transformer.encode", but it uses
  "universal_transformer_util.universal_transformer_encoder" instead of
  "transformer.transformer_encoder".

  Args:
    inputs: Transformer inputs [batch_size, input_length, input_height,
      hidden_dim] which will be flattened along the two spatial dimensions.
    target_space: scalar, target space ID.
    hparams: hyperparameters for model.
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.
    losses: Unused.
    **kwargs: additional arguments to pass to the encoder function.

  Returns:
    Tuple of:
      encoder_output: Encoder representation.
        [batch_size, input_length, hidden_dim]
      encoder_decoder_attention_bias: Bias and mask weights for
        encoder-decoder attention. [batch_size, input_length]
      encoder_extra_output: extra encoder output used in some variants of the
        model (e.g. in ACT, to pass the ponder-time to body).
  """
  del losses

  inputs = common_layers.flatten4d3d(inputs)

  encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
      transformer.transformer_prepare_encoder(
          inputs, target_space, hparams, features=features))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)

  (encoder_output, encoder_extra_output) = invertible_UT_encoder(
      encoder_input,
      self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, "inputs"),
      save_weights_to=self.attention_weights)

  return encoder_output, encoder_decoder_attention_bias, encoder_extra_output
def body(self, features):
  """Transformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
        "targets": Target decoder outputs.
            [batch_size, decoder_length, hidden_dim]
        "target_space_id": A scalar int from data_generators.problem.SpaceID.

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
  """
  tf.logging.info("Using PgScratch BODY function.")
  hparams = self._hparams

  losses = {}
  inputs = features["inputs"]
  target_space = features["target_space_id"]
  # encoder_output: <tf.float32>[batch_size, input_length, hidden_dim]
  # encoder_decoder_attention_bias: <tf.float32>[batch_size, input_length]
  encoder_output, encoder_decoder_attention_bias = self.encode(
      inputs, target_space, hparams, features=features, losses=losses)

  with tf.variable_scope("knowledge"):
    with tf.name_scope("knowledge_encoding"):
      # Encode knowledge.
      # <tf.float32>[batch_size, triple_num, emb_dim]
      fact_embedding, fact_lengths = self.encode_knowledge_bottom(features)
      tf.logging.info("Encoded knowledge")

    with tf.name_scope("knowledge_selection_and_loss"):
      # Compute knowledge selection and loss.
      (triple_logits, avg_triple_selection_loss, knowledge_encoder_output,
       transe_loss) = self.compute_knowledge_selection_and_loss(
           features, encoder_output, fact_embedding, fact_lengths,
           hparams.margin, hparams.num_negative_samples)
      losses["kb_loss"] = avg_triple_selection_loss
      losses["transe_loss"] = transe_loss

  if hparams.attend_kb:
    tf.logging.info("ATTEND_KB is ACTIVE")
    with tf.name_scope("knowledge_attention"):
      knowledge_padding = tf.zeros_like(triple_logits, dtype=tf.float32)
      knowledge_attention_bias = (
          common_attention.attention_bias_ignore_padding(knowledge_padding))
      encoder_output = tf.concat([knowledge_encoder_output, encoder_output],
                                 1)
      encoder_decoder_attention_bias = tf.concat(
          [knowledge_attention_bias, encoder_decoder_attention_bias], -1)
  else:
    tf.logging.info("ATTEND_KB is INACTIVE")

  targets = features["targets"]
  targets_shape = common_layers.shape_list(targets)
  targets = common_layers.flatten4d3d(targets)

  (decoder_input,
   decoder_self_attention_bias) = transformer.transformer_prepare_decoder(
       targets, hparams, features=features)

  decode_kwargs = {}
  decoder_output = self.decode(
      decoder_input,
      encoder_output,
      encoder_decoder_attention_bias,
      decoder_self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, "targets"),
      losses=losses,
      **decode_kwargs)

  expected_attentions = features.get("expected_attentions")
  if expected_attentions is not None:
    attention_loss = common_attention.encoder_decoder_attention_loss(
        expected_attentions, self.attention_weights,
        hparams.expected_attention_loss_type,
        hparams.expected_attention_loss_multiplier)
    return decoder_output, {"attention_loss": attention_loss}

  ret = tf.reshape(decoder_output, targets_shape)
  if losses:
    return ret, losses
  else:
    return ret
def body(self, features):
  """R-Transformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
        "targets": Target decoder outputs.
            [batch_size, decoder_length, hidden_dim]
        "target_space_id"

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
  """
  hparams = self._hparams
  if self.has_input:
    inputs = features["inputs"]
    target_space = features["target_space_id"]
    (encoder_output, encoder_decoder_attention_bias,
     enc_extra_output) = self.encode(
         inputs, target_space, hparams, features=features)
  else:
    (encoder_output, encoder_decoder_attention_bias,
     enc_extra_output) = (None, None, (None, None))

  targets = features["targets"]
  targets = common_layers.flatten4d3d(targets)

  (decoder_input,
   decoder_self_attention_bias) = transformer.transformer_prepare_decoder(
       targets, hparams, features=features)

  decoder_output, dec_extra_output = self.decode(
      decoder_input,
      encoder_output,
      encoder_decoder_attention_bias,
      decoder_self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, "targets"))

  expected_attentions = features.get("expected_attentions")
  if expected_attentions is not None:
    attention_loss = common_attention.encoder_decoder_attention_loss(
        expected_attentions, self.attention_weights,
        hparams.expected_attention_loss_type,
        hparams.expected_attention_loss_multiplier)
    return decoder_output, {"attention_loss": attention_loss}

  if hparams.recurrence_type == "act" and hparams.act_loss_weight != 0:
    if self.has_input:
      enc_ponder_times, enc_remainders = enc_extra_output
      enc_act_loss = (
          hparams.act_loss_weight *
          tf.reduce_mean(enc_ponder_times + enc_remainders))
    else:
      enc_act_loss = 0.0

    (dec_ponder_times, dec_remainders) = dec_extra_output
    dec_act_loss = (
        hparams.act_loss_weight *
        tf.reduce_mean(dec_ponder_times + dec_remainders))
    act_loss = enc_act_loss + dec_act_loss
    tf.summary.scalar("act_loss", act_loss)
    return decoder_output, {"act_loss": act_loss}

  return decoder_output
def body(self, features):
  """Universal Transformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
        "targets": Target decoder outputs.
            [batch_size, decoder_length, hidden_dim]
        "target_space_id"

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
  """
  hparams = self._hparams
  if hparams.add_position_timing_signal:
    # Turning off addition of positional embedding in the encoder/decoder
    # preparation as we do it in the beginning of each step.
    hparams.pos = None

  if self.has_input:
    inputs = features["inputs"]
    target_space = features["target_space_id"]
    (encoder_output, encoder_decoder_attention_bias,
     enc_extra_output) = self.encode(
         inputs, target_space, hparams, features=features)
  else:
    (encoder_output, encoder_decoder_attention_bias,
     enc_extra_output) = (None, None, (None, None))

  targets = features["targets"]
  targets = common_layers.flatten4d3d(targets)

  (decoder_input,
   decoder_self_attention_bias) = transformer.transformer_prepare_decoder(
       targets, hparams, features=features)

  decoder_output, dec_extra_output = self.decode(
      decoder_input,
      encoder_output,
      encoder_decoder_attention_bias,
      decoder_self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, "targets"))

  expected_attentions = features.get("expected_attentions")
  if expected_attentions is not None:
    attention_loss = common_attention.encoder_decoder_attention_loss(
        expected_attentions, self.attention_weights,
        hparams.expected_attention_loss_type,
        hparams.expected_attention_loss_multiplier)
    return decoder_output, {"attention_loss": attention_loss}

  if hparams.recurrence_type == "act" and hparams.act_loss_weight != 0:
    if self.has_input:
      enc_ponder_times, enc_remainders = enc_extra_output
      enc_act_loss = (
          hparams.act_loss_weight *
          tf.reduce_mean(enc_ponder_times + enc_remainders))
    else:
      enc_act_loss = 0.0

    (dec_ponder_times, dec_remainders) = dec_extra_output
    dec_act_loss = (
        hparams.act_loss_weight *
        tf.reduce_mean(dec_ponder_times + dec_remainders))
    act_loss = enc_act_loss + dec_act_loss
    tf.contrib.summary.scalar("act_loss", act_loss)
    return decoder_output, {"act_loss": act_loss}

  return decoder_output
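# The ACT penalty above is a scalar weight times the mean of per-position
# ponder times plus halting remainders, which encourages positions to halt
# early. A small numpy sketch of the arithmetic (toy values; act_loss_weight
# mirrors the hparam referenced above):
import numpy as np

_act_loss_weight = 0.01

# Per-position ACT statistics: how many steps each position pondered, and
# the leftover halting probability (remainder) at the halting step.
_ponder_times = np.array([[2., 3., 1.], [4., 2., 2.]])
_remainders = np.array([[0.3, 0.1, 0.6], [0.2, 0.5, 0.4]])

# Penalizing ponder_times + remainders pushes positions to halt sooner.
_act_loss = _act_loss_weight * np.mean(_ponder_times + _remainders)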
def body(self, features):
  """CopyTransformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
        "targets": Target decoder outputs.
            [batch_size, decoder_length, hidden_dim]
        "targets_*": Additional decoder outputs to generate, for copying and
            pointing; [batch_size, decoder_length]
        "target_space_id": A scalar int from data_generators.problem.SpaceID.

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
  """
  hparams = self._hparams

  losses = []

  inputs = features["inputs"]
  target_space = features["target_space_id"]
  encoder_output, encoder_decoder_attention_bias = self.encode(
      inputs, target_space, hparams, features=features, losses=losses)

  if "targets_actions" in features:
    targets = features["targets_actions"]
  else:
    tf.logging.warn(
        "CopyTransformer must be used with a SemanticParsing problem with a "
        "ShiftReduceGrammar; bad things will happen otherwise")
    targets = features["targets"]

  targets_shape = common_layers.shape_list(targets)
  targets = common_layers.flatten4d3d(targets)
  decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
      targets, hparams, features=features)
  decoder_output = self.decode(
      decoder_input,
      encoder_output,
      encoder_decoder_attention_bias,
      decoder_self_attention_bias,
      hparams,
      nonpadding=features_to_nonpadding(features, "targets"),
      losses=losses)

  expected_attentions = features.get("expected_attentions")
  if expected_attentions is not None:
    attention_loss = common_attention.encoder_decoder_attention_loss(
        expected_attentions, self.attention_weights,
        hparams.expected_attention_loss_type,
        hparams.expected_attention_loss_multiplier)
    return decoder_output, {"attention_loss": attention_loss}

  decoder_output = tf.reshape(decoder_output, targets_shape)

  body_output = dict()
  target_modality = (
      self._problem_hparams.target_modality
      if self._problem_hparams else {"targets": None})

  assert hparams.pointer_layer in ("attentive", "decaying_attentive")

  for key, modality in target_modality.items():
    if isinstance(modality, CopyModality):
      with tf.variable_scope("copy_layer/" + key):
        if hparams.pointer_layer == "decaying_attentive":
          output_layer = DecayingAttentivePointerLayer(encoder_output)
        else:
          output_layer = AttentivePointerLayer(encoder_output)
        scores = output_layer(decoder_output)
        scores += encoder_decoder_attention_bias
        body_output[key] = scores
    else:
      body_output[key] = decoder_output

  if losses:
    return body_output, {"extra_loss": tf.add_n(losses)}
  else:
    return body_output
def body(self, features):
  """Universal Transformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
        "targets": Target decoder outputs.
            [batch_size, decoder_length, hidden_dim]
        "target_space_id"

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
  """
  hparams = self._hparams
  if hparams.add_position_timing_signal:
    # Turning off addition of positional embedding in the encoder/decoder
    # preparation as we do it in the beginning of each step.
    hparams.pos = None

  if self.has_input:
    inputs = features["inputs"]
    target_space = features["target_space_id"]
    (encoder_output, encoder_decoder_attention_bias,
     enc_extra_output) = self.encode(
         inputs, target_space, hparams, features=features)
  else:
    (encoder_output, encoder_decoder_attention_bias,
     enc_extra_output) = (None, None, (None, None))

  targets = features["targets"]
  targets = common_layers.flatten4d3d(targets)

  (decoder_input,
   decoder_self_attention_bias) = transformer.transformer_prepare_decoder(
       targets, hparams, features=features)

  decoder_output, dec_extra_output = self.decode(
      decoder_input,
      encoder_output,
      encoder_decoder_attention_bias,
      decoder_self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, "targets"))

  expected_attentions = features.get("expected_attentions")
  if expected_attentions is not None:
    tf.logging.info("returning attention loss")
    attention_loss = common_attention.encoder_decoder_attention_loss(
        expected_attentions, self.attention_weights,
        hparams.expected_attention_loss_type,
        hparams.expected_attention_loss_multiplier)
    return decoder_output, {"attention_loss": attention_loss}

  if hparams.recurrence_type == "act" and hparams.act_loss_weight != 0:
    tf.logging.info("returning act loss")
    if self.has_input:
      enc_ponder_times, enc_remainders = enc_extra_output
      enc_act_loss = (
          hparams.act_loss_weight *
          tf.reduce_mean(enc_ponder_times + enc_remainders))
    else:
      enc_act_loss = 0.0

    (dec_ponder_times, dec_remainders) = dec_extra_output
    dec_act_loss = (
        hparams.act_loss_weight *
        tf.reduce_mean(dec_ponder_times + dec_remainders))
    act_loss = enc_act_loss + dec_act_loss
    tf.contrib.summary.scalar("act_loss", act_loss)
    return decoder_output, {"act_loss": act_loss}

  # grads = get_grads_and_vars(attention_loss)
  # dec_out_and_grads = tf.concat([decoder_output, grads], 1)  # 0 or 1?
  access_output, access_state = self._access(decoder_output,
                                             dec_extra_output)

  return decoder_output, DNCState(
      access_output=access_output,
      access_state=access_state,
      controller_state=dec_extra_output)
def body(self, features):
  """Transformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs.
            [batch_size, input_length, 1, hidden_dim].
        "targets": Target decoder outputs.
            [batch_size, decoder_length, 1, hidden_dim]
        "target_space_id": A scalar int from data_generators.problem.SpaceID.

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
  """
  hparams = self._hparams

  losses = []

  if self.has_input:
    inputs = features["inputs"]
    target_space = features["target_space_id"]
    encoder_output, encoder_decoder_attention_bias = self.encode(
        inputs, target_space, hparams, features=features, losses=losses)
  else:
    encoder_output, encoder_decoder_attention_bias = (None, None)

  targets = features["targets"]
  targets_shape = common_layers.shape_list(targets)
  targets = common_layers.flatten4d3d(targets)

  left_decoder_input, left_decoder_self_attention_bias = (
      transformer_prepare_decoder(targets, hparams, features=features))
  right_decoder_input, right_decoder_self_attention_bias = (
      transformer_prepare_decoder_right(targets, hparams, features=features))

  non_pad = features_to_nonpadding(features, "targets")

  with tf.variable_scope("left_decoder"):
    left_decoder_output = self.decode(
        left_decoder_input,
        encoder_output,
        encoder_decoder_attention_bias,
        left_decoder_self_attention_bias,
        hparams,
        nonpadding=non_pad,
        losses=losses)

  with tf.variable_scope("right_decoder"):
    right_decoder_output = self.decode(
        right_decoder_input,
        encoder_output,
        encoder_decoder_attention_bias,
        right_decoder_self_attention_bias,
        hparams,
        nonpadding=non_pad,
        losses=losses)

  decoder_output = transformer_bidirectional_joint_decoder(
      tf.squeeze(left_decoder_output, axis=2),
      tf.squeeze(right_decoder_output, axis=2),
      encoder_output,
      encoder_decoder_attention_bias,
      hparams,
      nonpadding=non_pad,
      save_weights_to=self.attention_weights,
      losses=losses)
  decoder_output = tf.expand_dims(decoder_output, axis=2)

  expected_attentions = features.get("expected_attentions")
  if expected_attentions is not None:
    attention_loss = common_attention.encoder_decoder_attention_loss(
        expected_attentions, self.attention_weights,
        hparams.expected_attention_loss_type,
        hparams.expected_attention_loss_multiplier)
    return decoder_output, {"attention_loss": attention_loss}

  ret = tf.reshape(decoder_output, targets_shape)
  if losses:
    return ret, {"extra_loss": tf.add_n(losses)}
  else:
    return ret
def body(self, features, original_features):
  """Transformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
        "targets": Target decoder outputs.
            [batch_size, decoder_length, hidden_dim]
        "target_space_id"
    original_features: Map of the corresponding unprocessed features, passed
      to inputs_encoding as original_input.

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
  """
  hparams = self._hparams
  snippets = features.get(searchqa_problem.FeatureNames.SNIPPETS)
  questions = features.get(searchqa_problem.FeatureNames.QUESTION)
  target_space = features["target_space_id"]

  with tf.variable_scope("input"):
    # [batch_size, search_results_len, embed_sz]
    encoded_snippets = self.inputs_encoding(
        input=snippets,
        original_input=original_features.get(
            searchqa_problem.FeatureNames.SNIPPETS),
        initializer=tf.constant_initializer(1.0),
        scope="snippets_encoding")

    # [batch_size, 1, embed_sz]
    encoded_question = self.inputs_encoding(
        input=questions,
        original_input=original_features.get(
            searchqa_problem.FeatureNames.QUESTION),
        initializer=tf.constant_initializer(1.0),
        scope="question_encoding")

  # Concat snippets and questions to create the inputs.
  inputs = tf.concat([encoded_snippets, encoded_question], axis=1)
  # The input is 4D by default and it gets squeezed from 4D to 3D in the
  # encode function, so we need to make it 4D by inserting the channel dim.
  inputs = tf.expand_dims(inputs, axis=2)

  losses = []
  encoder_output, encoder_decoder_attention_bias = self.encode(
      inputs, target_space, hparams, features=features, losses=losses)

  targets = features["targets"]
  targets_shape = common_layers.shape_list(targets)
  targets = common_layers.flatten4d3d(targets)
  decoder_input, decoder_self_attention_bias = (
      transformer.transformer_prepare_decoder(
          targets, hparams, features=features))
  decoder_output = self.decode(
      decoder_input,
      encoder_output,
      encoder_decoder_attention_bias,
      decoder_self_attention_bias,
      hparams,
      nonpadding=features_to_nonpadding(features, "targets"),
      losses=losses)

  ret = tf.reshape(decoder_output, targets_shape)
  if losses:
    return ret, {"extra_loss": tf.add_n(losses)}
  else:
    return ret