def attention_lm_moe_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
      to implement masked attention and possibly biases for diagonal alignments
    pad_remover (expert_utils.PadRemover): a util object to remove padding
  """
  targets_pad_mask = common_attention.embedding_to_padding(targets)
  with tf.name_scope("pad_remover"):
    # Because of the shift_right, the <eos> token will be considered as
    # padding. In practice it doesn't really matter: thanks to the triangular
    # mask, this token should never be attended to.
    pad_remover = expert_utils.PadRemover(targets_pad_mask)

  if hparams.prepend_mode == "prepend_inputs_full_attention":
    decoder_self_attention_bias = (
        common_attention.attention_bias_prepend_inputs_full_attention(
            targets_pad_mask))
  else:
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias, pad_remover)
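# A minimal numpy sketch (not the tensor2tensor implementation) of the two
# decoder-prep building blocks used above: shifting the targets right by one
# position and building a lower-triangular (causal) attention bias. The helper
# names below are illustrative stand-ins for common_layers.shift_right_3d and
# common_attention.attention_bias_lower_triangle.
import numpy as np

def shift_right_3d_sketch(x):
  # x: [batch, length, depth]. Prepend a zero step and drop the last one,
  # so position t only ever sees embeddings of positions < t.
  return np.pad(x, [(0, 0), (1, 0), (0, 0)])[:, :-1, :]

def lower_triangle_bias_sketch(length, neg=-1e9):
  # [1, 1, length, length] bias: 0 on/below the diagonal, a large negative
  # value above it, so softmax attention cannot look at future positions.
  band = np.tril(np.ones([length, length]))
  return ((1.0 - band) * neg)[None, None, :, :]

targets = np.arange(6, dtype=np.float32).reshape(1, 3, 2)  # toy [1, 3, 2]
print(shift_right_3d_sketch(targets)[0])    # first time step is now all zeros
print(lower_triangle_bias_sketch(3)[0, 0])  # causal mask for length 3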
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(targets)[1]))
  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    if targets_position is not None:
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
def prepare_decoder(targets, hparams):
  """Prepare decoder for images."""
  targets_shape = common_layers.shape_list(targets)
  channels = hparams.num_channels
  curr_infer_length = None

  # During training, images are [batch, IMG_LEN, IMG_LEN, 3].
  # At inference, they are [batch, curr_infer_length, 1, 1].
  if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
    curr_infer_length = targets_shape[1]
    if hparams.block_raster_scan:
      assert hparams.img_len * channels % hparams.query_shape[1] == 0
      assert hparams.img_len % hparams.query_shape[0] == 0
      total_block_width = hparams.img_len * channels
      # Decoding is in block raster scan order. We divide the image into
      # hparams.query_shape blocks and then decode each block in raster scan.
      # To make that compatible with our inference pipeline, pad the target so
      # that rows is a multiple of query_shape and columns is a multiple of
      # hparams.img_len * channels.
      curr_infer_length = targets_shape[1]
      block_padding_factor = total_block_width * hparams.query_shape[0]
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % block_padding_factor],
          [0, 0], [0, 0]])

      num_blocks = total_block_width // hparams.query_shape[1]
      # Reshape the image to represent blocks.
      target_blocks = tf.reshape(
          targets, [targets_shape[0], -1, num_blocks, hparams.query_shape[0],
                    hparams.query_shape[1]])
      # Transpose to read the image in 2D fashion.
      targets = tf.transpose(target_blocks, [0, 1, 3, 2, 4])
    else:
      # Add padding to make sure the size of targets is a multiple of
      # img_height times the number of channels. This is needed for positional
      # encodings and for doing the RGB lookup.
      padding_factor = channels * hparams.img_len
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % padding_factor], [0, 0], [0, 0]])
    targets = tf.reshape(targets,
                         [targets_shape[0], -1, hparams.img_len, channels])
  # Preprocess image.
  x = prepare_image(targets, hparams, name="dec_channels")
  x_shape = common_layers.shape_list(x)
  if (hparams.dec_attention_type == AttentionType.LOCAL_2D or
      hparams.dec_attention_type == AttentionType.LOCAL_BLOCK):
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
    x = add_pos_signals(x, hparams, "dec_pos")
  else:
    # Add position signals.
    x = tf.reshape(x, [targets_shape[0], x_shape[1] * x_shape[2],
                       hparams.hidden_size])
    x = common_layers.shift_right_3d(x)
    x = tf.reshape(x, [targets_shape[0], x_shape[1], x_shape[2],
                       hparams.hidden_size])
    x = add_pos_signals(x, hparams, "dec_pos")
  x = common_layers.cast_like(x, targets)
  return x, x_shape[1], x_shape[2]
def prepare_decoder(targets, hparams):
  """Prepare decoder for images."""
  targets_shape = common_layers.shape_list(targets)
  channels = hparams.num_channels
  curr_infer_length = None

  # During training, images are [batch, IMG_LEN, IMG_LEN, 3].
  # At inference, they are [batch, curr_infer_length, 1, 1].
  if (hparams.mode == tf.contrib.learn.ModeKeys.INFER and
      hparams.block_raster_scan):
    curr_infer_length = targets_shape[1]
    if hparams.block_raster_scan:
      assert hparams.img_len * channels % hparams.query_shape[1] == 0
      assert hparams.img_len % hparams.query_shape[0] == 0
      total_block_width = hparams.img_len * channels
      # Decoding is in block raster scan order. We divide the image into
      # hparams.query_shape blocks and then decode each block in raster scan.
      # To make that compatible with our inference pipeline, pad the target so
      # that rows is a multiple of query_shape and columns is a multiple of
      # hparams.img_len * channels.
      curr_infer_length = targets_shape[1]
      block_padding_factor = total_block_width * hparams.query_shape[0]
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % block_padding_factor],
          [0, 0], [0, 0]])

      num_blocks = total_block_width // hparams.query_shape[1]
      # Reshape the image to represent blocks.
      target_blocks = tf.reshape(
          targets, [targets_shape[0], -1, num_blocks, hparams.query_shape[0],
                    hparams.query_shape[1]])
      # Transpose to read the image in 2D fashion.
      targets = tf.transpose(target_blocks, [0, 1, 3, 2, 4])
    else:
      # Add padding to make sure the size of targets is a multiple of
      # img_height times the number of channels. This is needed for positional
      # encodings and for doing the RGB lookup.
      padding_factor = channels * hparams.img_len
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % padding_factor], [0, 0], [0, 0]])
    targets = tf.reshape(targets,
                         [targets_shape[0], -1, hparams.img_len, channels])
  # Preprocess image.
  x = prepare_image(targets, hparams, name="dec_channels")
  x_shape = common_layers.shape_list(x)
  if (hparams.dec_attention_type == AttentionType.LOCAL_2D or
      hparams.dec_attention_type == AttentionType.LOCAL_BLOCK):
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
    x = add_pos_signals(x, hparams, "dec_pos")
  else:
    # Add position signals.
    x = tf.reshape(x, [targets_shape[0], x_shape[1] * x_shape[2],
                       hparams.hidden_size])
    x = common_layers.shift_right_3d(x)
    x = tf.reshape(x, [targets_shape[0], x_shape[1], x_shape[2],
                       hparams.hidden_size])
    x = add_pos_signals(x, hparams, "dec_pos")
  return x, x_shape[1], x_shape[2]
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  if hparams.causal_decoder_self_attention:
    # Causal attention.
    if hparams.prepend_mode == "prepend_inputs_full_attention":
      decoder_self_attention_bias = (
          common_attention.attention_bias_prepend_inputs_full_attention(
              common_attention.embedding_to_padding(targets)))
    else:
      decoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(
              common_layers.shape_list(targets)[1]))
  else:
    # Full attention.
    decoder_padding = common_attention.embedding_to_padding(targets)
    decoder_self_attention_bias = (
        common_attention.attention_bias_ignore_padding(decoder_padding))

  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    if targets_position is not None:
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  elif hparams.pos == "emb":
    decoder_input = common_attention.add_positional_embedding(
        decoder_input, hparams.max_length, "targets_positional_embedding",
        targets_position)

  if hparams.activation_dtype == "bfloat16":
    decoder_self_attention_bias = tf.cast(decoder_self_attention_bias,
                                          tf.bfloat16)
  return (decoder_input, decoder_self_attention_bias)
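# Hedged numpy sketch of the "packed dataset" segment bias added above:
# positions belonging to different packed examples get a large negative bias
# so they cannot attend to each other. This is an illustrative stand-in, not
# the real common_attention.attention_bias_same_segment implementation.
import numpy as np

def same_segment_bias_sketch(query_segment, memory_segment, neg=-1e9):
  # query_segment, memory_segment: [batch, length] integer segment ids.
  # Returns a [batch, 1, query_length, memory_length] additive bias.
  different = (query_segment[:, :, None] != memory_segment[:, None, :])
  return different.astype(np.float32)[:, None, :, :] * neg

seg = np.array([[1, 1, 2, 2, 2]])  # two packed examples in one row
print(same_segment_bias_sketch(seg, seg)[0, 0])  # block-diagonal zeros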
def prepare_decoder(targets, target_space_emb):
  """Prepare decoder."""
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  target_space_emb = tf.reshape(target_space_emb, [1, 1, -1])
  target_space_emb = tf.tile(target_space_emb, [tf.shape(targets)[0], 1, 1])
  decoder_input = common_layers.shift_right_3d(
      targets, pad_value=target_space_emb)
  decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(targets)[1]))
  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  # if hparams.pos == "timing":
  #   if targets_position is not None:
  #     decoder_input = common_attention.add_timing_signal_1d_given_position(
  #         decoder_input, targets_position)
  #   else:
  #     decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  raw_decoder_input = common_layers.shift_right(features["targets_raw"])
  terminal_decoder_bias, nonterminal_decoder_bias = _get_t_nt_bias(
      raw_decoder_input, hparams, decoder_self_attention_bias)
  pop_decoder_bias = _get_pop_bias(raw_decoder_input, hparams)
  raw_decoder_input = tf.squeeze(raw_decoder_input, axis=[-2, -1])
  pos_signals = generate_positional_signals(
      raw_decoder_input, hparams, terminal_decoder_bias,
      nonterminal_decoder_bias)
  pos_embeddings = generate_positional_embeddings(
      pos_signals, hparams.decoder_pos, hparams)
  if "sum" in hparams.decoder_pos_integration:
    decoder_input = decoder_input + pos_embeddings
  elif "ffn" in hparams.decoder_pos_integration:
    with tf.variable_scope("decoder_pos_ffn"):
      decoder_input = tf.concat([decoder_input, pos_embeddings], axis=2)
      decoder_input = transformer_ffn_layer(
          decoder_input, hparams, conv_padding="LEFT")
  return (decoder_input, decoder_self_attention_bias, terminal_decoder_bias,
          nonterminal_decoder_bias, pop_decoder_bias, pos_signals)
def transformer_prepare_decoder(targets, hparams):
  """Copied from tensor2tensor.models.transformer."""
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
def model_fn_body(self, features):
  hparams = self._hparams
  if not hparams.attention or hparams.attention_architecture != "standard":
    raise ValueError(
        "Layer-by-layer version of TF-NMT only available for "
        "the 'standard' attention model architecture.")
  inputs, inputs_length = usr_utils.get_feature_with_length(
      features, "inputs")
  target_roots, target_roots_length = usr_utils.get_feature_with_length(
      features, "target_roots")
  targets, targets_length = usr_utils.get_feature_with_length(
      features, "targets")
  # We need to do +1 for inference since get_feature_with_length()
  # may not have direct access to sequence lengths and returns
  # a length of 0 for the first inference step.
  if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
    targets_length = targets_length + 1
  # Input lengths of 0 break things.
  inputs_length = tf.maximum(inputs_length, 1)
  target_roots_length = tf.maximum(target_roots_length, 1)
  # Shift targets right to use them as input.
  targets = common_layers.shift_right_3d(targets)
  # Manage POP signals.
  if hparams.target_root_attention == "pop":
    raw_targets = tf.squeeze(
        tf.squeeze(features["raw_targets"], axis=2), axis=2)
    targets_is_pop = tf.equal(raw_targets, hparams.pop_id)
  else:
    targets_is_pop = None
  iterator = TFNmtLayerbylayerInput(
      initializer=None,
      source=inputs,
      target_input=targets,
      target_input_is_pop=targets_is_pop,
      target_output=None,  # Loss is computed in T2T
      target_root=target_roots,
      source_sequence_length=inputs_length,
      target_sequence_length=targets_length,
      target_root_sequence_length=target_roots_length)
  tfnmt_model = TFNmtLayerbylayerModel(
      hparams_helper.convert_to_tfnmt_hparams(hparams),
      iterator=iterator,
      mode=tf.contrib.learn.ModeKeys.EVAL,  # We use the eval graph for training.
      source_vocab_table=FakeVocabTable(),
      target_vocab_table=FakeVocabTable())
  decoder_output = tfnmt_model.logits
  return tf.expand_dims(decoder_output, axis=2)
def transformer_edit_ops_layer(decoder_input,
                               hparams,
                               encoder_output,
                               features,
                               cache=None,
                               decode_loop_step=None,
                               nonpadding=None,
                               losses=None,
                               layer_collection=None):
  """Layer that conditions on the error tag and start and end token pointers."""
  if isinstance(encoder_output, list):
    # Select forward encoder
    encoder_output = encoder_output[0]
  with tf.variable_scope("edit_ops_layer"):
    with tf.variable_scope("ffn"):
      x = decoder_input
      # Shorthand for layer preprocessing
      # pylint: disable=g-long-lambda
      preproc = lambda z: common_layers.layer_preprocess(
          z, hparams, layer_collection=layer_collection)
      # pylint: enable=g-long-lambda
      feedback_start_token = (hparams.use_start_token or
                              not hparams.feedback_end_token)
      if feedback_start_token:
        start_token = _pointer_feedback(
            features["targets_start_token"], encoder_output,
            shift=hparams.feedback_end_token)
      if hparams.feedback_end_token:
        end_token = _pointer_feedback(features["targets_end_token"],
                                      encoder_output)
      layer_inputs = [preproc(x)]
      if hparams.use_error_tags:
        error_tags = common_layers.shift_right_3d(
            common_layers.flatten4d3d(features["targets_error_tag"]))
        layer_inputs.append(preproc(error_tags))
      if feedback_start_token:
        layer_inputs.append(start_token)
      if hparams.feedback_end_token:
        layer_inputs.append(end_token)
      y = transformer_layers.transformer_ffn_layer(
          tf.concat(layer_inputs, axis=2),
          hparams,
          conv_padding="LEFT",
          nonpadding_mask=nonpadding,
          losses=losses,
          cache=cache,
          decode_loop_step=decode_loop_step,
          layer_collection=layer_collection)
      x = common_layers.layer_postprocess(x, y, hparams)
      return x
def body(self, features):
  """Computes pre-softmax decoder activations.

  Args:
    features["inputs"], features["targets"]: tensors with shape
      [batch_size, ..., hidden_size]

  Returns:
    decoder_outputs: pre-softmax activations of the same size as the inputs.

  The input is assumed to be a time series, so its effective shape is
  [batch_size, sequence_length, hidden_size].
  """
  inputs = features["inputs"]
  targets = features["targets"]
  # tensor2tensor provides 4-D tensors in which axis 2 is unused,
  # so we remove it for ease of handling.
  original_shape = common_layers.shape_list(inputs)
  squeeze_shape_inputs = [
      x for x in common_layers.shape_list(inputs) if x != 1]
  squeeze_shape_targets = [
      x for x in common_layers.shape_list(targets) if x != 1]
  # Squeeze unneeded dimensions.
  inputs = tf.reshape(inputs, squeeze_shape_inputs)
  targets = tf.reshape(targets, squeeze_shape_targets)
  decoder_inputs = common_layers.shift_right_3d(targets)
  # The encoder bias causes padding to be ignored.
  inputs_embedding_mask = common_attention.embedding_to_padding(inputs)
  self.encoder_attention_bias = common_attention.attention_bias_ignore_padding(
      inputs_embedding_mask)
  # The decoder bias causes targets to only attend to
  # previous positions (and themselves).
  self.decoder_attention_bias = common_attention.attention_bias_lower_triangle(
      common_layers.shape_list(targets)[1])
  # Process the encoder and save the result for the decoder to use,
  # then process the decoder.
  self.encoder_outputs = self.adaptive_computation(inputs, self.encode)
  outputs = self.adaptive_computation(decoder_inputs, self.decode)
  # Reshape the output back to 4-D.
  outputs = tf.reshape(outputs, original_shape)
  return outputs
def transformer_fast_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_position_forward_mask: mask Tensor for position-forward. [1, t, 1]
  """
  length = tf.shape(targets)[1]
  decoder_position_forward_mask = 1. / tf.expand_dims(
      tf.expand_dims(tf.to_float(tf.range(length)) + 1., 0), -1)  # [1, t, 1]
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_position_forward_mask)
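# Tiny numpy sketch of the position-forward mask built above: entry t is
# 1 / (t + 1), shaped [1, length, 1], i.e. a per-position averaging weight for
# cumulative (position-forward) sums. Illustrative only, not the TF graph code.
import numpy as np

length = 4
position_forward_mask = (
    1.0 / (np.arange(length, dtype=np.float32) + 1.0))[None, :, None]
print(position_forward_mask[0, :, 0])  # [1.0, 0.5, 0.3333, 0.25]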
def transformer_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  return (decoder_input, decoder_self_attention_bias)
def transformer_edit_ops_layer(
    decoder_input,
    hparams,
    encoder_output,
    features,
    cache=None,
    decode_loop_step=None,
    nonpadding=None,
    losses=None,
    layer_collection=None,
):
  """Layer that conditions on the error tag and start and end token pointers."""
  if isinstance(encoder_output, list):
    # Select forward encoder
    encoder_output = encoder_output[0]
  with tf.variable_scope('edit_ops_layer'):
    with tf.variable_scope('ffn'):
      x = decoder_input
      # Shorthand for layer preprocessing
      # pylint: disable=g-long-lambda
      preproc = lambda z: common_layers.layer_preprocess(
          z, hparams, layer_collection=layer_collection)
      # pylint: enable=g-long-lambda
      layer_inputs = [preproc(x)]
      error_tags = common_layers.shift_right_3d(
          common_layers.flatten4d3d(features['targets_error_tag']))
      layer_inputs.append(preproc(error_tags))
      y = transformer_layers.transformer_ffn_layer(
          tf.concat(layer_inputs, axis=2),
          hparams,
          conv_padding='LEFT',
          nonpadding_mask=nonpadding,
          losses=losses,
          cache=cache,
          decode_loop_step=decode_loop_step,
          layer_collection=layer_collection,
      )
      x = common_layers.layer_postprocess(x, y, hparams)
      return x
def attention_lm_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
      to implement masked attention and possibly biases for diagonal alignments
  """
  if hparams.prepend_mode == "prepend_inputs_full_attention":
    decoder_self_attention_bias = (
        common_attention.attention_bias_prepended(
            common_attention.embedding_to_padding(targets)))
  else:
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
def attention_lm_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
      to implement masked attention and possibly biases for diagonal alignments
  """
  if hparams.prepend_mode == "prepend_inputs_full_attention":
    decoder_self_attention_bias = (
        common_attention.attention_bias_prepended(
            common_attention.embedding_to_padding(targets)))
  else:
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(targets)[1]))
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
def transformer_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  # decoder_input = tf.Print(decoder_input, [tf.shape(decoder_input)],
  #                          summarize=1000, message="decoder_input")
  # decoder_self_attention_bias = tf.Print(
  #     decoder_self_attention_bias, [tf.shape(decoder_self_attention_bias)],
  #     summarize=1000, message="decoder_self_attention_bias")
  return (decoder_input, decoder_self_attention_bias)
def body(self, features):
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  # Run the basic autoencoder part first.
  basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
  if hparams.autoregressive_mode == "none":
    assert not hparams.autoregressive_forget_base
    return basic_result, losses
  shape = common_layers.shape_list(basic_result)
  basic1d = tf.reshape(basic_result, [shape[0], -1, shape[3]])
  # During autoregressive inference, don't resample.
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    if hasattr(hparams, "sampled_basic1d_tensor"):
      basic1d = hparams.sampled_basic1d_tensor
    else:
      hparams.sampled_basic1d_tensor = basic1d
  # Prepare inputs for autoregressive modes.
  if common_layers.shape_list(features["targets"])[1] == 1:
    # This happens on the first step of predictions.
    assert hparams.mode == tf.estimator.ModeKeys.PREDICT
    features["targets"] = tf.zeros_like(basic_result)
  targets_dropout = common_layers.mix(
      features["targets"],
      tf.zeros_like(basic_result),
      hparams.bottleneck_warmup_steps,
      is_training,
      max_prob=1.0 - hparams.autoregressive_dropout,
      broadcast_last=True)
  # Sometimes it's useful to look at non-autoregressive evals.
  if (hparams.mode == tf.estimator.ModeKeys.EVAL and
      hparams.autoregressive_eval_pure_autoencoder):
    targets_dropout = tf.zeros_like(basic_result)
  # Now combine the basic reconstruction with shifted targets.
  targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[3]])
  targets_shifted = common_layers.shift_right_3d(targets1d)
  concat1d = tf.concat([basic1d, targets_shifted], axis=-1)
  # The forget_base hparam sets purely-autoregressive mode, no autoencoder.
  if hparams.autoregressive_forget_base:
    concat1d = tf.reshape(features["targets"], [shape[0], -1, shape[3]])
    concat1d = common_layers.shift_right_3d(concat1d)
  # The autoregressive part depends on the mode.
  if hparams.autoregressive_mode == "conv3":
    res = common_layers.conv1d(
        concat1d,
        shape[3],
        3,
        padding="LEFT",
        activation=common_layers.belu,
        name="autoregressive_conv3")
    return tf.reshape(res, shape), losses
  if hparams.autoregressive_mode == "conv5":
    res = common_layers.conv1d(
        concat1d,
        shape[3],
        5,
        padding="LEFT",
        activation=common_layers.belu,
        name="autoregressive_conv5")
    return tf.reshape(res, shape), losses
  if hparams.autoregressive_mode == "sru":
    res = common_layers.conv1d(
        concat1d,
        shape[3],
        3,
        padding="LEFT",
        activation=common_layers.belu,
        name="autoregressive_sru_conv3")
    res = common_layers.sru(res)
    return tf.reshape(res, shape), losses
  raise ValueError(
      "Unsupported autoregressive mode: %s" % hparams.autoregressive_mode)
def _build_decoder_agreement_loss(self, central_lang_tag="<en>"):
  """Builds an agreement loss that enforces consistency of the decodings.

  Args:
    central_lang_tag: A string with the tag of the central language.
      A ``central'' language (usually English) is the one that has parallel
      data with all other languages. It is used to protect supervised
      directions from gradients coming from auxiliary losses.

  Returns:
    loss: <float32> [] for the agreement losses.
  """
  # Get target embeddings and vocab size.
  target_modality = self._problem_hparams.modality["targets"]
  target_modality_scope = self._variable_scopes[target_modality.name]
  target_embeddings = model_utils.get_embeddings(
      modality=target_modality,
      outer_scope=target_modality_scope,
      inner_scope="shared")
  target_vocab_size = target_modality._vocab_size  # pylint: disable=protected-access

  # Build auxiliary sequences (if necessary).
  aux_keys = self._build_aux_sequences(
      target_embeddings, target_vocab_size,
      central_lang_tag=central_lang_tag)

  # Build loss.
  aux_loss = 0.
  with tf.name_scope("dec_agreement_loss"):
    for key1, key2 in zip(aux_keys, aux_keys[::-1]):
      # Prepare for decoding.
      targets = self.dec_outputs[key2]["rnn_output"]
      targets_length = self.dec_outputs[key2]["length"]
      shifted_targets = common_layers.shift_right_3d(targets)
      hiddens = self.enc_outputs[key1].outputs
      hiddens_length = self.inputs[key1][1]
      enc_state = self.enc_outputs[key1].final_state
      # Decode.
      decode_func = self.get_decode_func(
          target_embeddings,
          shifted_targets,
          targets_length,
          hiddens,
          hiddens_length,
          enc_state,
          mode=tf.estimator.ModeKeys.PREDICT,
          decoder_iterations=self._hparams.aux_decode_length)
      aux_dec_outputs = decode_func()
      # Compute logits (protect central directions from the gradients).
      aux_logits_1 = model_utils.build_logits(
          sequences=tf.expand_dims(aux_dec_outputs["rnn_output"], axis=2),
          embeddings=target_embeddings,
          vocab_size=target_vocab_size)
      aux_logits_1 = tf.where(self._is_central[key1],
                              tf.stop_gradient(aux_logits_1),
                              aux_logits_1)
      # Compute KL loss.
      logits = tf.squeeze(aux_logits_1, axis=2)
      if self._hparams.dec_agreement_loss_sparse:
        target_ids = self.dec_outputs[key2]["sample_id"]
        aux_loss = aux_loss + losses.CrossEntropyLoss(sparse=True)(
            logits, target_ids, targets_length)
      else:
        aux_logits_2 = tf.squeeze(self.dec_outputs[key2]["logits"], axis=2)
        target_probs = tf.nn.softmax(aux_logits_2, axis=-1)
        aux_loss = aux_loss + losses.CrossEntropyLoss(sparse=False)(
            logits, target_probs, targets_length)
  aux_loss = self._hparams.dec_agreement_coeff * aux_loss
  return aux_loss
def body(self, features):
  hparams = self.hparams
  # Run the basic autoencoder part first.
  basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
  if hparams.autoregressive_mode == "none":
    assert not hparams.autoregressive_forget_base
    return basic_result, losses
  if "training" in losses:
    plain_training_loss = losses.pop("training")
    losses["plain"] = plain_training_loss
  res_shape = common_layers.shape_list(basic_result)
  vocab_size = self._problem_hparams.vocab_size["targets"]
  if hasattr(self._hparams, "vocab_divisor"):
    vocab_size += (-vocab_size) % self._hparams.vocab_divisor
  targets = tf.one_hot(features["targets_raw"], vocab_size)
  # Prepare inputs for autoregressive modes.
  if common_layers.shape_list(features["targets"])[1] == 1:
    # This happens on the first step of predictions.
    assert hparams.mode == tf.estimator.ModeKeys.PREDICT
    targets = tf.zeros_like(basic_result)
  targets = self.embed(targets)
  if hparams.autoregressive_gumbel_sample:
    basic_hot = self.gumbel_sample(basic_result)
  else:
    basic_hot = basic_result
  basic_result = self.embed(basic_hot)
  shape = common_layers.shape_list(basic_result)
  basic1d = tf.reshape(basic_result, [shape[0], -1, shape[-1]])
  targets = tf.reshape(targets, common_layers.shape_list(basic_result))
  # During autoregressive inference, don't resample.
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    if hasattr(hparams, "sampled_basic1d_tensor"):
      basic1d = hparams.sampled_basic1d_tensor
    else:
      hparams.sampled_basic1d_tensor = basic1d
  # Sometimes it's useful to look at non-autoregressive evals.
  targets_dropout = targets
  if (hparams.mode == tf.estimator.ModeKeys.EVAL and
      hparams.autoregressive_eval_pure_autoencoder):
    targets_dropout = tf.zeros_like(basic_result)
  # Now combine the basic reconstruction with shifted targets.
  targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[-1]])
  targets_shifted = common_layers.shift_right_3d(targets1d)
  concat1d = tf.concat([basic1d, targets_shifted], axis=-1)
  # The forget_base hparam sets purely-autoregressive mode, no autoencoder.
  if hparams.autoregressive_forget_base:
    concat1d = tf.reshape(targets, [shape[0], -1, shape[-1]])
    concat1d = common_layers.shift_right_3d(concat1d)
  # The autoregressive part depends on the mode.
  if hparams.autoregressive_mode == "conv3":
    res = common_layers.conv1d(
        concat1d,
        hparams.hidden_size,
        3,
        padding="LEFT",
        activation=common_layers.belu,
        name="autoregressive_conv3")
    res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
    return tf.reshape(res, res_shape), losses
  if hparams.autoregressive_mode == "conv5":
    res = common_layers.conv1d(
        concat1d,
        hparams.hidden_size,
        5,
        padding="LEFT",
        activation=common_layers.belu,
        name="autoregressive_conv5")
    res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
    return tf.reshape(res, res_shape), losses
  if hparams.autoregressive_mode == "sru":
    res = common_layers.conv1d(
        concat1d,
        hparams.hidden_size,
        3,
        padding="LEFT",
        activation=common_layers.belu,
        name="autoregressive_sru_conv3")
    res = common_layers.sru(res)
    res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
    return tf.reshape(res, res_shape), losses
  raise ValueError(
      "Unsupported autoregressive mode: %s" % hparams.autoregressive_mode)
def body(self, features):
  hparams = self.hparams
  # Run the basic autoencoder part first.
  basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
  if hparams.autoregressive_mode == "none":
    assert not hparams.autoregressive_forget_base
    return basic_result, losses
  if "training" in losses:
    plain_training_loss = losses.pop("training")
    losses["plain"] = plain_training_loss
  res_shape = common_layers.shape_list(basic_result)
  vocab_size = self._problem_hparams.modality["targets"].top_dimensionality
  targets = tf.one_hot(features["targets_raw"], vocab_size)
  # Prepare inputs for autoregressive modes.
  if common_layers.shape_list(features["targets"])[1] == 1:
    # This happens on the first step of predictions.
    assert hparams.mode == tf.estimator.ModeKeys.PREDICT
    targets = tf.zeros_like(basic_result)
  targets = self.embed(targets)
  if hparams.autoregressive_gumbel_sample:
    basic_hot = self.gumbel_sample(basic_result)
  else:
    basic_hot = basic_result
  basic_result = self.embed(basic_hot)
  shape = common_layers.shape_list(basic_result)
  basic1d = tf.reshape(basic_result, [shape[0], -1, shape[-1]])
  targets = tf.reshape(targets, common_layers.shape_list(basic_result))
  # During autoregressive inference, don't resample.
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    if hasattr(hparams, "sampled_basic1d_tensor"):
      basic1d = hparams.sampled_basic1d_tensor
    else:
      hparams.sampled_basic1d_tensor = basic1d
  # Sometimes it's useful to look at non-autoregressive evals.
  targets_dropout = targets
  if (hparams.mode == tf.estimator.ModeKeys.EVAL and
      hparams.autoregressive_eval_pure_autoencoder):
    targets_dropout = tf.zeros_like(basic_result)
  # Now combine the basic reconstruction with shifted targets.
  targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[-1]])
  targets_shifted = common_layers.shift_right_3d(targets1d)
  concat1d = tf.concat([basic1d, targets_shifted], axis=-1)
  # The forget_base hparam sets purely-autoregressive mode, no autoencoder.
  if hparams.autoregressive_forget_base:
    concat1d = tf.reshape(targets, [shape[0], -1, shape[-1]])
    concat1d = common_layers.shift_right_3d(concat1d)
  # The autoregressive part depends on the mode.
  if hparams.autoregressive_mode == "conv3":
    res = common_layers.conv1d(
        concat1d,
        hparams.hidden_size,
        3,
        padding="LEFT",
        activation=common_layers.belu,
        name="autoregressive_conv3")
    res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
    return tf.reshape(res, res_shape), losses
  if hparams.autoregressive_mode == "conv5":
    res = common_layers.conv1d(
        concat1d,
        hparams.hidden_size,
        5,
        padding="LEFT",
        activation=common_layers.belu,
        name="autoregressive_conv5")
    res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
    return tf.reshape(res, res_shape), losses
  if hparams.autoregressive_mode == "sru":
    res = common_layers.conv1d(
        concat1d,
        hparams.hidden_size,
        3,
        padding="LEFT",
        activation=common_layers.belu,
        name="autoregressive_sru_conv3")
    res = common_layers.sru(res)
    res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
    return tf.reshape(res, res_shape), losses
  raise ValueError(
      "Unsupported autoregressive mode: %s" % hparams.autoregressive_mode)
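# Hedged numpy sketch of the LEFT-padded "conv3"/"conv5" autoregressive modes
# used in the autoencoder bodies above: a width-k convolution whose input is
# padded on the left only, so each output position depends solely on the
# current and previous time steps. Illustrative stand-in, not the real
# common_layers.conv1d(..., padding="LEFT").
import numpy as np

def causal_conv1d_sketch(x, kernel):
  # x: [length, depth_in], kernel: [width, depth_in, depth_out].
  width = kernel.shape[0]
  padded = np.pad(x, [(width - 1, 0), (0, 0)])  # pad on the left only
  return np.stack(
      [np.einsum("wd,wde->e", padded[t:t + width], kernel)
       for t in range(x.shape[0])])

x = np.random.rand(5, 2).astype(np.float32)   # toy [length=5, depth_in=2]
k = np.random.rand(3, 2, 4).astype(np.float32)  # width-3 kernel, depth_out=4
print(causal_conv1d_sketch(x, k).shape)  # (5, 4)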