def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder
      self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(targets)[1]))
  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    if targets_position is not None:
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
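# --- Illustrative usage (not part of the original source). A minimal sketch,
# assuming the tensor2tensor API seen elsewhere in these snippets: the
# (decoder_input, bias) pair returned above is typically fed straight into
# transformer.transformer_decoder together with the encoder outputs.
# `encoder_output` and `encoder_decoder_attention_bias` stand in for values
# produced by the matching encoder; `_decode_sketch` itself is hypothetical.
def _decode_sketch(targets, encoder_output, encoder_decoder_attention_bias,
                   hparams):
  decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
      targets, hparams)
  decoder_input = tf.nn.dropout(
      decoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
  return transformer.transformer_decoder(
      decoder_input, encoder_output, decoder_self_attention_bias,
      encoder_decoder_attention_bias, hparams)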
def attention_lm_moe_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
      to implement masked attention and possibly biases for diagonal
      alignments
    pad_remover (expert_utils.PadRemover): a utility object to remove padding
  """
  targets_pad_mask = common_attention.embedding_to_padding(targets)
  with tf.name_scope("pad_remover"):
    # Because of the shift_right, the <eos> token will be considered as
    # padding. In practice, it doesn't really matter: due to the triangular
    # mask, this token should never be attended to.
    pad_remover = expert_utils.PadRemover(targets_pad_mask)

  if hparams.prepend_mode == "prepend_inputs_full_attention":
    decoder_self_attention_bias = (
        common_attention.attention_bias_prepend_inputs_full_attention(
            targets_pad_mask))
  else:
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias, pad_remover)
def decode_inputs_to_outputs(self, decoder_embed_inputs, encoder_outputs,
                             encoder_attn_bias, rule_id_input_placeholder,
                             mem_contexts, mem_outputs, global_step):
  if self.hparams.pos == 'timing':
    decoder_embed_inputs = common_attention.add_timing_signal_1d(
        decoder_embed_inputs)
    print('Use positional encoding in decoder text.')
  decoder_attn_bias = common_attention.attention_bias_lower_triangle(
      tf.shape(decoder_embed_inputs)[1])
  decoder_embed_inputs = tf.nn.dropout(
      decoder_embed_inputs, 1.0 - self.hparams.layer_prepostprocess_dropout)

  if 'rule' in self.model_config.memory:
    decoder_output, contexts = transformer.transformer_decoder2(
        decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
        encoder_attn_bias, self.hparams)
    # encoder_gate_w = tf.get_variable('encoder_gate_w', shape=(
    #     1, self.model_config.dimension, 1))
    # encoder_gate_b = tf.get_variable('encoder_gate_b', shape=(1, 1, 1))
    # encoder_gate = tf.tanh(encoder_gate_b + tf.nn.conv1d(
    #     encoder_outputs, encoder_gate_w, 1, 'SAME'))
    # encoder_context_outputs = tf.expand_dims(
    #     tf.reduce_mean(encoder_outputs * encoder_gate, axis=1), axis=1)
    cur_context = contexts[0]  # tf.concat(contexts, axis=-1)
    cur_mem_contexts = tf.stack(
        self.embedding_fn(rule_id_input_placeholder, mem_contexts), axis=1)
    cur_mem_outputs = tf.stack(
        self.embedding_fn(rule_id_input_placeholder, mem_outputs), axis=1)
    bias = tf.expand_dims(
        -1e9 * tf.to_float(
            tf.equal(tf.stack(rule_id_input_placeholder, axis=1), 0)),
        axis=1)
    weights = tf.nn.softmax(
        bias + tf.matmul(cur_context, cur_mem_contexts, transpose_b=True))
    mem_output = tf.matmul(weights, cur_mem_outputs)
    temp_output = tf.concat((decoder_output, mem_output), axis=-1)
    w = tf.get_variable(
        'w_ffn',
        shape=(1, self.model_config.dimension * 2,
               self.model_config.dimension))
    # b = tf.get_variable('b_ffn', shape=(1, 1, self.model_config.dimension))
    mem_output = tf.nn.conv1d(temp_output, w, 1, 'SAME')
    g = tf.greater(
        global_step,
        tf.constant(2 * self.model_config.memory_prepare_step,
                    dtype=tf.int64))
    final_output = tf.cond(g, lambda: mem_output, lambda: decoder_output)
    return final_output, decoder_output, cur_context
  else:
    decoder_output = transformer.transformer_decoder(
        decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
        encoder_attn_bias, self.hparams)
    final_output = decoder_output
    return final_output, decoder_output, None
def _apply_decoder_layer(translation_layer, input_tensor, output_depth,
                         encoder_depth):
  """Applies a decoder layer with basic arguments."""
  residual_tensor_values = np.random.rand(
      *[_BATCH_SIZE, _TOTAL_SEQUENCE_LENGTH, output_depth]) - .5
  residual_tensor = tf.constant(residual_tensor_values, dtype=tf.float32)
  encoder_output_values = np.random.rand(
      *[_BATCH_SIZE, _TOTAL_SEQUENCE_LENGTH, encoder_depth]) - .5
  encoder_output = tf.constant(encoder_output_values, dtype=tf.float32)
  encoder_block_outputs = [encoder_output] * _NUM_BLOCKS
  hparams = transformer.transformer_base()
  hparams.attention_dropout = 0
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(_TOTAL_SEQUENCE_LENGTH))
  output_tensor = translation_layer.apply_layer(
      input_tensor,
      residual_tensor,
      output_depth,
      None,
      hparams,
      "",
      nonpadding=None,
      mask_future=True,
      layer_preprocess_fn=None,
      postprocess_dropout=False,
      decoder_self_attention_bias=decoder_self_attention_bias,
      encoder_decoder_attention_bias=None,
      encoder_block_outputs=encoder_block_outputs,
      block_number=_BLOCK_NUMBER)
  return output_tensor
def test_nas_decoder_resizing_output(self):
  hparams, wrong_size = self._get_wrong_output_dim_decoder_hparams()
  hparams.enforce_output_size = False
  input_tensor = tf.zeros([_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH])
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(_INPUT_LENGTH))
  with tf.variable_scope("wrong"):
    wrong_size_decoder_output = translation_nas_net.nas_decoder(
        decoder_input=input_tensor,
        encoder_cell_outputs=[input_tensor] * hparams.encoder_num_cells,
        decoder_self_attention_bias=decoder_self_attention_bias,
        encoder_decoder_attention_bias=None,
        hparams=hparams)

  # Now add the correction.
  hparams.enforce_output_size = True
  with tf.variable_scope("correct"):
    correct_size_decoder_output = translation_nas_net.nas_decoder(
        decoder_input=input_tensor,
        encoder_cell_outputs=[input_tensor] * hparams.encoder_num_cells,
        decoder_self_attention_bias=decoder_self_attention_bias,
        encoder_decoder_attention_bias=None,
        hparams=hparams)

  with self.test_session() as session:
    session.run(tf.global_variables_initializer())
    wrong_output, correct_output = session.run(
        [wrong_size_decoder_output, correct_size_decoder_output])

  self.assertEqual(wrong_output.shape,
                   (_BATCH_SIZE, _INPUT_LENGTH, wrong_size))
  self.assertEqual(correct_output.shape,
                   (_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH))
def attention_lm_moe_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
      to implement masked attention and possibly biases for diagonal
      alignments
    pad_remover (expert_utils.PadRemover): a utility object to remove padding
  """
  targets_pad_mask = common_attention.embedding_to_padding(targets)
  with tf.name_scope("pad_remover"):
    pad_remover = expert_utils.PadRemover(targets_pad_mask)

  if hparams.prepend_mode == "prepend_inputs_full_attention":
    decoder_self_attention_bias = (
        common_attention.attention_bias_prepended(targets_pad_mask))
  else:
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  decoder_input = common_layers.shift_left_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias, pad_remover)
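# --- Illustrative usage (not part of the original source). A minimal sketch
# of how the returned pad_remover is typically used, assuming tensor2tensor's
# expert_utils.PadRemover API (remove()/restore() operate on tensors
# flattened to [batch * length, depth]): padding rows are stripped before an
# expensive feed-forward/MoE layer and reinserted afterwards. `ffn_fn` and
# `_ffn_without_padding_sketch` are hypothetical names.
def _ffn_without_padding_sketch(x, pad_remover, ffn_fn, hidden_size):
  # x: [batch, length, hidden_size]
  original_shape = tf.shape(x)
  x = tf.reshape(x, [-1, hidden_size])  # flatten to [batch * length, depth]
  x = pad_remover.remove(x)             # drop rows at padded positions
  x = ffn_fn(x)                         # run the layer on real tokens only
  x = pad_remover.restore(x)            # reinsert zero rows for padding
  return tf.reshape(x, original_shape)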
def prepare_decoder(targets, hparams):
  """Prepare decoder for images."""
  targets_shape = common_layers.shape_list(targets)
  channels = hparams.num_channels
  curr_infer_length = None

  # During training, images are [batch, IMG_LEN, IMG_LEN, 3].
  # At inference, they are [batch, curr_infer_length, 1, 1].
  if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
    curr_infer_length = targets_shape[1]
    if hparams.block_rastor_scan:
      assert hparams.img_len * channels % hparams.query_shape[1] == 0
      assert hparams.img_len % hparams.query_shape[0] == 0
      total_block_width = hparams.img_len * channels
      # Decoding is in block raster scan order. We divide the image into
      # hparams.query_shape blocks and then decode each block in raster scan
      # order. To make that compatible with our inference pipeline, pad the
      # target so that rows is a multiple of query_shape and columns is a
      # multiple of hparams.img_len*channels.
      curr_infer_length = targets_shape[1]
      block_padding_factor = total_block_width * hparams.query_shape[0]
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % block_padding_factor],
          [0, 0], [0, 0]])

      num_blocks = total_block_width // hparams.query_shape[1]
      # Reshape the image to represent blocks.
      target_blocks = tf.reshape(
          targets, [targets_shape[0], -1, num_blocks,
                    hparams.query_shape[0], hparams.query_shape[1]])
      # Transpose to read the image in 2D fashion.
      targets = tf.transpose(target_blocks, [0, 1, 3, 2, 4])
    else:
      # Add padding to make sure the size of targets is a multiple of
      # img_height times the number of channels. This is needed for
      # positional encodings and for doing the RGB lookup.
      padding_factor = channels * hparams.img_len
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % padding_factor], [0, 0], [0, 0]])
    targets = tf.reshape(targets,
                         [targets_shape[0], -1, hparams.img_len, channels])
  # Preprocess image.
  x = prepare_image(targets, hparams, name="dec_channels")
  x_shape = common_layers.shape_list(x)
  # Mask out the upper triangle to avoid looking into the future.
  bias = common_attention.attention_bias_lower_triangle(
      x_shape[1] * x_shape[2])
  if hparams.dec_attention_type == AttentionType.LOCAL_2D:
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
    x = add_pos_signals(x, hparams, "dec_pos")
  else:
    # Add position signals.
    x = tf.reshape(x, [targets_shape[0], x_shape[1] * x_shape[2],
                       hparams.hidden_size])
    x = common_layers.shift_right_3d(x)
    x = tf.reshape(x, [targets_shape[0], x_shape[1], x_shape[2],
                       hparams.hidden_size])
    x = add_pos_signals(x, hparams, "dec_pos")
  return x, x_shape[1], x_shape[2], bias
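# --- Worked example (illustrative, not from the original source). The pad
# amounts above use Python's modulo on a negated length: (-length) % factor
# is the smallest non-negative pad that rounds `length` up to a multiple of
# `factor`. For example, with length=10 and factor=8:
assert -10 % 8 == 6              # pad by 6 positions
assert (10 + (-10 % 8)) % 8 == 0  # 16 is a multiple of 8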
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder
      self-attention
  """
  if hparams.causal_decoder_self_attention:
    # Causal attention.
    if hparams.prepend_mode == "prepend_inputs_full_attention":
      decoder_self_attention_bias = (
          common_attention.attention_bias_prepend_inputs_full_attention(
              common_attention.embedding_to_padding(targets)))
    else:
      decoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(
              common_layers.shape_list(targets)[1]))
  else:
    # Full attention.
    decoder_padding = common_attention.embedding_to_padding(targets)
    decoder_self_attention_bias = (
        common_attention.attention_bias_ignore_padding(decoder_padding))

  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    if targets_position is not None:
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  elif hparams.pos == "emb":
    decoder_input = common_attention.add_positional_embedding(
        decoder_input, hparams.max_length, "targets_positional_embedding",
        targets_position)
  if hparams.activation_dtype == "bfloat16":
    decoder_self_attention_bias = tf.cast(decoder_self_attention_bias,
                                          tf.bfloat16)
  return (decoder_input, decoder_self_attention_bias)
def prepare_decoder(targets, target_space_emb):
  """Prepare decoder."""
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  target_space_emb = tf.reshape(target_space_emb, [1, 1, -1])
  target_space_emb = tf.tile(target_space_emb, [tf.shape(targets)[0], 1, 1])
  decoder_input = common_layers.shift_right_3d(
      targets, pad_value=target_space_emb)
  decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
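# --- Illustrative sketch (not from the original source), assuming
# tensor2tensor's shift_right_3d semantics: targets are shifted one position
# to the right along the time axis, and the freed first position is filled
# with pad_value (here, the target-space embedding), so the decoder predicts
# token t from tokens < t. A NumPy equivalent:
import numpy as np

def shift_right_3d_np(x, pad_value):
  # x: [batch, length, depth]; pad_value: [batch, 1, depth]
  return np.concatenate([pad_value, x[:, :-1, :]], axis=1)

x = np.arange(6, dtype=np.float32).reshape(1, 3, 2)   # 3 timesteps
pad = np.full((1, 1, 2), -1.0, dtype=np.float32)
print(shift_right_3d_np(x, pad))  # first step is pad; last step is dropped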
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder
      self-attention
    terminal_decoder_bias, nonterminal_decoder_bias, pop_decoder_bias:
      additional bias tensors derived from the raw targets
    pos_signals: positional signals derived from the raw targets
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(targets)[1]))
  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  # if hparams.pos == "timing":
  #   if targets_position is not None:
  #     decoder_input = common_attention.add_timing_signal_1d_given_position(
  #         decoder_input, targets_position)
  #   else:
  #     decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  raw_decoder_input = common_layers.shift_right(features['targets_raw'])
  terminal_decoder_bias, nonterminal_decoder_bias = _get_t_nt_bias(
      raw_decoder_input, hparams, decoder_self_attention_bias)
  pop_decoder_bias = _get_pop_bias(raw_decoder_input, hparams)
  raw_decoder_input = tf.squeeze(raw_decoder_input, axis=[-2, -1])
  pos_signals = generate_positional_signals(
      raw_decoder_input, hparams, terminal_decoder_bias,
      nonterminal_decoder_bias)
  pos_embeddings = generate_positional_embeddings(
      pos_signals, hparams.decoder_pos, hparams)
  if "sum" in hparams.decoder_pos_integration:
    decoder_input = decoder_input + pos_embeddings
  elif "ffn" in hparams.decoder_pos_integration:
    with tf.variable_scope("decoder_pos_ffn"):
      decoder_input = tf.concat([decoder_input, pos_embeddings], axis=2)
      decoder_input = transformer_ffn_layer(
          decoder_input, hparams, conv_padding="LEFT")
  return (decoder_input, decoder_self_attention_bias, terminal_decoder_bias,
          nonterminal_decoder_bias, pop_decoder_bias, pos_signals)
def transformer_prepare_decoder(targets, hparams):
  """Copied from tensor2tensor.models.transformer."""
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
def testMultiheadSelfAttentionMemoryEfficient(self):
  if tf.executing_eagerly():
    return  # don't run test in Eager mode
  num_heads = 4
  io_size = 16
  batch = 2
  length = 7
  head_size = 5
  x = np.random.rand(batch, length, io_size)
  dy = np.random.rand(batch, length, io_size)
  with self.session() as session:
    x = tf.to_float(x)
    dy = tf.to_float(dy)
    bias = common_attention.attention_bias_lower_triangle(length)
    wqkv = tf.get_variable(
        "wqkv", [num_heads, 1, io_size, 3 * head_size],
        initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
    wo = tf.get_variable(
        "wo", [num_heads, 1, head_size, io_size],
        initializer=tf.random_normal_initializer(
            stddev=(head_size * num_heads)**-0.5))
    norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
    y = common_attention.multihead_self_attention_memory_efficient(
        x,
        bias,
        num_heads,
        head_size=head_size,
        forget=False,
        test_vars=(wqkv, wo, norm_scale, norm_bias))
    y_forget = common_attention.multihead_self_attention_memory_efficient(
        x,
        bias,
        num_heads,
        head_size=head_size,
        forget=True,
        test_vars=(wqkv, wo, norm_scale, norm_bias))
    dx, dwqkv, dwo, dnorm_scale, dnorm_bias = tf.gradients(
        ys=[y], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
    dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f = tf.gradients(
        ys=[y_forget], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
    session.run(tf.global_variables_initializer())
    (y, y_forget, dx, dwqkv, dwo, dnorm_scale, dnorm_bias, dx_f, dwqkv_f,
     dwo_f, dnorm_scale_f, dnorm_bias_f) = session.run([
         y, y_forget, dx, dwqkv, dwo, dnorm_scale, dnorm_bias, dx_f, dwqkv_f,
         dwo_f, dnorm_scale_f, dnorm_bias_f
     ])
  self.assertAllClose(y, y_forget)
  self.assertAllClose(dwo, dwo_f)
  self.assertAllClose(dwqkv, dwqkv_f)
  self.assertAllClose(dnorm_scale, dnorm_scale_f)
  self.assertAllClose(dnorm_bias, dnorm_bias_f)
  self.assertAllClose(dx, dx_f)
def decode_inputs_to_outputs(self, kword_input, abstr_outputs, abstr_bias,
                             hist_vector=None):
  if self.hparams.pos == 'timing':
    kword_input = common_attention.add_timing_signal_1d(kword_input)
  kword_tribias = common_attention.attention_bias_lower_triangle(
      tf.shape(kword_input)[1])
  kword_input = tf.nn.dropout(
      kword_input, 1.0 - self.hparams.layer_prepostprocess_dropout)
  kword_output = transformer.transformer_decoder(
      kword_input, abstr_outputs, kword_tribias, abstr_bias,
      self.hparams, hist_vector=hist_vector)
  return kword_output
def decode(cond_vec, cond_add, gold, c, ed, hparams):
  """Transformer decoder."""
  drop_gold = tf.nn.dropout(gold, 1.0 - hparams.layer_prepostprocess_dropout)
  decoder_input = common_layers.shift_right(drop_gold, pad_value=cond_vec)
  if cond_add is not None:
    decoder_input += cond_add
  decoder_input = tf.squeeze(decoder_input, axis=2)
  decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  bias = common_attention.attention_bias_lower_triangle(tf.shape(gold)[1])
  if c is not None and len(c.get_shape()) > 3:
    c = tf.squeeze(c, axis=2)
  return transformer.transformer_decoder(decoder_input, c, bias, ed, hparams)
def get_self_attention_bias(x):
  """Creates masked self attention bias.

  Args:
    x: A tensor of shape [batch, length, depth]

  Returns:
    self_attention_bias: A bias tensor of shape [1, 1, length, length]
  """
  x_shape = common_layers.shape_list(x)
  self_attention_bias = common_attention.attention_bias_lower_triangle(
      x_shape[1])
  return self_attention_bias
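# --- Illustrative sketch (not from the original source). The lower-triangle
# bias is broadcast-added to attention logits: entries on or below the
# diagonal are 0 (attendable), entries above it are a large negative number
# (effectively masked out by the softmax). A NumPy picture for length=4:
import numpy as np

length = 4
bias = np.triu(np.full((length, length), -1e9), k=1)  # 0 on/below diagonal
bias = bias[np.newaxis, np.newaxis, :, :]             # [1, 1, length, length]
# Row i of bias[0, 0] masks positions j > i, so step i can only attend to
# steps 0..i.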
def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
  """Complete attention layer with preprocessing."""
  separabilities = [hparams.separability, hparams.separability]
  if hparams.separability < 0:
    separabilities = [hparams.separability - 1, hparams.separability]
  targets_timed = common_layers.subseparable_conv_block(
      common_layers.add_timing_signal(targets_shifted),
      hparams.model_d, [((1, 1), (5, 1)), ((4, 1), (5, 1))],
      normalizer_fn=norm_fn,
      padding="LEFT",
      separabilities=separabilities,
      name="targets_time")
  if hparams.attention_type == "transformer":
    targets_timed = tf.squeeze(targets_timed, 2)
    target_shape = tf.shape(targets_timed)
    targets_segment = tf.zeros([target_shape[0], target_shape[1]])
    target_attention_bias = common_attention.attention_bias_lower_triangle(
        target_shape[1])
    inputs_encoded = common_layers.flatten4d3d(inputs_encoded)
    # TODO(jbaccash): use input bias parameter. This code seems to assume
    # fixed size inputs.
    inputs_attention_bias = tf.zeros([
        tf.shape(inputs_encoded)[0], hparams.num_heads,
        tf.shape(targets_segment)[1],
        tf.shape(inputs_encoded)[1]
    ])
    qv = common_attention.multihead_attention(
        targets_timed,
        None,
        target_attention_bias,
        hparams.model_d,
        hparams.model_d,
        hparams.model_d,
        hparams.num_heads,
        hparams.attention_dropout,
        name="self_attention")
    qv = common_attention.multihead_attention(
        qv,
        inputs_encoded,
        inputs_attention_bias,
        hparams.model_d,
        hparams.model_d,
        hparams.model_d,
        hparams.num_heads,
        hparams.attention_dropout,
        name="encdec_attention")
    return tf.expand_dims(qv, 2)
  else:
    raise ValueError("Unsupported attention_type: %s" %
                     hparams.attention_type)
def transformer_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder
      self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  return (decoder_input, decoder_self_attention_bias)
def decode_syntax_template(self, trg_syntax_emb):
  with tf.variable_scope('syntax_decoder', reuse=tf.AUTO_REUSE):
    trg_syntax_emb = common_attention.add_timing_signal_1d(trg_syntax_emb)
    trg_syntax_emb = self.update_embedding(trg_syntax_emb)
    trg_syntax_length = tf.shape(trg_syntax_emb)[1]
    trg_self_attention_bias = common_attention.attention_bias_lower_triangle(
        trg_syntax_length)
    trg_syntax_outputs = transformer.transformer_decoder(
        decoder_input=trg_syntax_emb,
        decoder_self_attention_bias=trg_self_attention_bias,
        encoder_output=self.shared_tensors['src_outputs'],
        encoder_decoder_attention_bias=self.shared_tensors['src_bias'],
        hparams=self.hparams,
        external_output=self.shared_tensors['template_prev_simp_outputs'],
        external_bias=self.shared_tensors['template_simp_bias'])
  return trg_syntax_outputs
def attention_lm_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
      to implement masked attention and possibly biases for diagonal
      alignments
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  decoder_input = common_layers.shift_left_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
def test_calculate_branching_model_parameters_decoder_resize(
    self, enforce_output_size):
  tf.reset_default_graph()
  hparams, _ = self._get_wrong_output_dim_decoder_hparams()
  hparams.enforce_output_size = enforce_output_size
  hparams.decoder_left_norms = [translation_nas_net.NO_NORM_KEY] * 5
  hparams.decoder_right_norms = [translation_nas_net.NO_NORM_KEY] * 5

  # Get predicted number of parameters.
  (predicted_num_params, _, _,
   _) = translation_nas_net.calculate_branching_model_parameters(
       encoding_depth=_EMBEDDING_DEPTH,
       left_inputs=hparams.decoder_left_inputs,
       left_layers=hparams.decoder_left_layers,
       left_output_dims=hparams.decoder_left_output_dims,
       right_inputs=hparams.decoder_right_inputs,
       right_layers=hparams.decoder_right_layers,
       right_output_dims=hparams.decoder_right_output_dims,
       combiner_functions=hparams.decoder_combiner_functions,
       final_combiner_function=hparams.decoder_final_combiner_function,
       layer_registry=layers.DECODER_LAYERS,
       num_cells=hparams.decoder_num_cells,
       encoder_depth=_EMBEDDING_DEPTH,
       enforce_output_size=enforce_output_size)

  # Count graph variables.
  input_tensor = tf.zeros([_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH])
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(_INPUT_LENGTH))
  _ = translation_nas_net.nas_decoder(
      decoder_input=input_tensor,
      encoder_cell_outputs=[input_tensor] * hparams.encoder_num_cells,
      decoder_self_attention_bias=decoder_self_attention_bias,
      encoder_decoder_attention_bias=None,
      hparams=hparams,
      final_layer_norm=False)
  trainable_variables_list = tf.trainable_variables()
  empirical_num_params = 0
  for variable_tensor in trainable_variables_list:
    empirical_num_params += _list_product(variable_tensor.shape.as_list())

  self.assertEqual(empirical_num_params, predicted_num_params)
def testMultiheadSelfAttentionMemoryEfficient(self):
  num_heads = 4
  io_size = 16
  batch = 2
  length = 7
  head_size = 5
  x = np.random.rand(batch, length, io_size)
  dy = np.random.rand(batch, length, io_size)
  with self.test_session() as session:
    x = tf.to_float(x)
    dy = tf.to_float(dy)
    bias = common_attention.attention_bias_lower_triangle(length)
    wqkv = tf.get_variable(
        "wqkv", [num_heads, 1, io_size, 3 * head_size],
        initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
    wo = tf.get_variable(
        "wo", [num_heads, 1, head_size, io_size],
        initializer=tf.random_normal_initializer(
            stddev=(head_size * num_heads)**-0.5))
    norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
    y = common_attention.multihead_self_attention_memory_efficient(
        x,
        bias,
        num_heads,
        head_size=head_size,
        forget=False,
        test_vars=(wqkv, wo, norm_scale, norm_bias))
    y_forget = common_attention.multihead_self_attention_memory_efficient(
        x,
        bias,
        num_heads,
        head_size=head_size,
        forget=True,
        test_vars=(wqkv, wo, norm_scale, norm_bias))
    dx, dwqkv, dwo, dnorm_scale, dnorm_bias = tf.gradients(
        ys=[y], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
    dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f = tf.gradients(
        ys=[y_forget], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
    session.run(tf.global_variables_initializer())
    (y, y_forget, dx, dwqkv, dwo, dnorm_scale, dnorm_bias, dx_f, dwqkv_f,
     dwo_f, dnorm_scale_f, dnorm_bias_f) = session.run(
         [y, y_forget, dx, dwqkv, dwo, dnorm_scale, dnorm_bias, dx_f,
          dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f])
  self.assertAllClose(y, y_forget)
  self.assertAllClose(dwo, dwo_f)
  self.assertAllClose(dwqkv, dwqkv_f)
  self.assertAllClose(dnorm_scale, dnorm_scale_f)
  self.assertAllClose(dnorm_bias, dnorm_bias_f)
  self.assertAllClose(dx, dx_f)
def transformer_prepare_decoder(targets_emb_var, targets, hparams,
                                features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets_emb_var: a Tensor
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder
      self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(targets)[1]))
  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  decoder_input = tf.gather(targets_emb_var,
                            common_layers.shift_right_2d(targets))
  with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    decoder_input = common_attention.add_positional_embedding(
        decoder_input, hparams.max_length, "positional_embedding",
        targets_position)
  if hparams.activation_dtype == "bfloat16":
    decoder_self_attention_bias = tf.cast(decoder_self_attention_bias,
                                          tf.bfloat16)
  return (decoder_input, decoder_self_attention_bias)
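# --- Illustrative sketch (not from the original source). The tf.gather call
# above is a plain embedding lookup: the shifted target ids index rows of
# the embedding variable. A NumPy equivalent with a hypothetical 5-word
# vocabulary and depth 8:
import numpy as np

emb_var = np.random.rand(5, 8)   # [vocab, depth]
ids = np.array([[3, 1, 4]])      # [batch, length]
decoder_input = emb_var[ids]     # gather -> [batch, length, depth]
assert decoder_input.shape == (1, 3, 8)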
def transformer_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder
      self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  # decoder_input = tf.Print(decoder_input, [tf.shape(decoder_input)],
  #                          summarize=1000, message="decoder_input")
  # decoder_self_attention_bias = tf.Print(
  #     decoder_self_attention_bias, [tf.shape(decoder_self_attention_bias)],
  #     summarize=1000, message="decoder_self_attention_bias")
  return (decoder_input, decoder_self_attention_bias)
def transformer_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder
      self-attention
  """
  decoder_self_attention_bias = (comm_attn.attention_bias_lower_triangle(
      tf.shape(targets)[1]))
  if hparams.proximity_bias:
    decoder_self_attention_bias += comm_attn.attention_bias_proximal(
        tf.shape(targets)[1])
  decoder_input = common_layers.shift_left_3d(targets)
  if hparams.pos == 'timing':
    decoder_input = comm_attn.add_timing_signal_1d(decoder_input)
  # Putting this here since it's always called immediately after...
  decoder_input = with_dropout(decoder_input, hparams)
  return DecoderState(input=decoder_input,
                      self_attn_bias=decoder_self_attention_bias)
def attention_lm_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
      to implement masked attention and possibly biases for diagonal
      alignments
  """
  if hparams.prepend_mode == "prepend_inputs_full_attention":
    decoder_self_attention_bias = (
        common_attention.attention_bias_prepended(
            common_attention.embedding_to_padding(targets)))
  else:
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(targets)[1]))
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1,
                 alpha=1.0):
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams
  target_modality = self._problem_hparams.modality["targets"]
  if "targets_segmentation" in features:
    raise NotImplementedError(
        "Decoding not supported on packed datasets."
        " If you want to decode from a dataset, use the non-packed version"
        " of the dataset when decoding.")
  if self.has_input:
    inputs = features["inputs"]
    if target_modality.is_class_modality:
      decode_length = 1
    else:
      decode_length = (common_layers.shape_list(inputs)[1] +
                       features.get("decode_length", decode_length))

    contexts = {}
    for feature_name in features:
      if 'context' in feature_name and 'raw' not in feature_name:
        contexts[feature_name] = features[feature_name]

    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = common_layers.shape_list(inputs)
    batch_size = s[0]
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.modality["inputs"]

    context_modality = {}
    for context_name in contexts:
      if context_name in self._problem_hparams.modality:
        context_modality[context_name] = self._problem_hparams.modality[
            context_name]
      else:
        context_modality[context_name] = input_modality

    with tf.variable_scope(input_modality.name, reuse=tf.AUTO_REUSE):
      inputs = input_modality.bottom_sharded(inputs, dp)
    for feature_name in contexts:
      with tf.variable_scope(context_modality[feature_name].name,
                             reuse=tf.AUTO_REUSE):
        contexts[feature_name] = context_modality[
            feature_name].bottom_sharded(contexts[feature_name], dp)

    contexts_list = [contexts[feature_name][0] for feature_name in contexts]
    contexts = tf.concat(contexts_list, axis=1)
    inputs = [tf.concat([contexts, inputs[0]], axis=1)]

    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode,
          inputs,
          features["target_space_id"],
          hparams,
          features=features)
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
    partial_targets = None
  else:
    # The problem has no inputs.
    encoder_output = None
    encoder_decoder_attention_bias = None

    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs")
    if partial_targets is None:
      partial_targets = features["targets"]
    assert partial_targets is not None
    partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
    partial_targets = tf.to_int64(partial_targets)
    partial_targets_shape = common_layers.shape_list(partial_targets)
    partial_targets_length = partial_targets_shape[1]
    decode_length = (partial_targets_length +
                     features.get("decode_length", decode_length))
    batch_size = partial_targets_shape[0]

  if hparams.pos == "timing":
    positional_encoding = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)
  elif hparams.pos == "emb":
    positional_encoding = common_attention.add_positional_embedding(
        tf.zeros([1, decode_length, hparams.hidden_size]),
        hparams.max_length, "body/targets_positional_embedding", None)
  else:
    positional_encoding = None

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: inputs ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features({"targets": targets})["targets"]
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom_sharded(targets, dp)[0]
    targets = common_layers.flatten4d3d(targets)

    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if positional_encoding is not None:
      targets += positional_encoding[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode,
          targets,
          cache.get("encoder_output"),
          cache.get("encoder_decoder_attention_bias"),
          bias,
          hparams,
          cache,
          nonpadding=features_to_nonpadding(features, "targets"))

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top_sharded(body_outputs, None, dp)[0]

    ret = tf.squeeze(logits, axis=[1, 2, 3])
    if partial_targets is not None:
      # If the position is within the given partial targets, we alter the
      # logits to always return those values.
      # A faster approach would be to process the partial targets in one
      # iteration in order to fill the corresponding parts of the cache.
      # This would require broader changes, though.
      vocab_size = tf.shape(ret)[1]

      def forced_logits():
        return tf.one_hot(
            tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
            -1e9)

      ret = tf.cond(
          tf.less(i, partial_targets_length), forced_logits, lambda: ret)
    return ret, cache

  ret = fast_decode(
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      symbols_to_logits_fn=symbols_to_logits_fn,
      hparams=hparams,
      decode_length=decode_length,
      vocab_size=target_modality.top_dimensionality,
      beam_size=beam_size,
      top_beams=top_beams,
      alpha=alpha,
      batch_size=batch_size,
      force_decode_length=self._decode_hparams.force_decode_length)
  if partial_targets is not None:
    if beam_size <= 1 or top_beams <= 1:
      ret["outputs"] = ret["outputs"][:, partial_targets_length:]
    else:
      ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
  return ret
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1,
                 alpha=1.0):
  """Fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    features: a map of string to model features.
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha, the
      stronger the preference for longer translations.

  Returns:
    samples: an integer `Tensor`. Top samples from the beam search

  Raises:
    NotImplementedError: If there are multiple data shards.
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams

  inputs = features["inputs"]
  batch_size = common_layers.shape_list(inputs)[0]
  target_modality = self._problem_hparams.target_modality
  if target_modality.is_class_modality:
    decode_length = 1
  else:
    decode_length = common_layers.shape_list(inputs)[1] + decode_length

  # TODO(llion): Clean up this reshaping logic.
  inputs = tf.expand_dims(inputs, axis=1)
  if len(inputs.shape) < 5:
    inputs = tf.expand_dims(inputs, axis=4)
  s = common_layers.shape_list(inputs)
  inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
  # _shard_features called to ensure that the variable names match
  inputs = self._shard_features({"inputs": inputs})["inputs"]
  input_modality = self._problem_hparams.input_modality["inputs"]
  with tf.variable_scope(input_modality.name):
    inputs = input_modality.bottom_sharded(inputs, dp)
  with tf.variable_scope("body"):
    encoder_output, encoder_decoder_attention_bias = dp(
        self.encode,
        inputs,
        features["target_space_id"],
        hparams,
        features=features)
  encoder_output = encoder_output[0]
  encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

  if hparams.pos == "timing":
    timing_signal = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: inputs ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features({"targets": targets})["targets"]
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom_sharded(targets, dp)[0]
    targets = common_layers.flatten4d3d(targets)

    # TODO(llion): Explain! Is this even needed?
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if hparams.pos == "timing":
      targets += timing_signal[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode, targets, cache["encoder_output"],
          cache["encoder_decoder_attention_bias"], bias, hparams, cache,
          nonpadding=_features_to_nonpadding(features, "targets"))

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top_sharded(body_outputs, None, dp)[0]

    return tf.squeeze(logits, axis=[1, 2, 3]), cache

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

  cache = {
      "layer_%d" % layer: {
          "k": tf.zeros([batch_size, 0, key_channels]),
          "v": tf.zeros([batch_size, 0, value_channels]),
      } for layer in range(num_layers)
  }

  # Set 2nd dim to None since it's not invariant in the tf.while_loop
  # Note: Tensor.set_shape() does not work here since it merges shape info.
  # TODO(llion); Find a more robust solution.
  # pylint: disable=protected-access
  if not context.in_eager_mode():
    for layer in cache:
      cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
      cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
  # pylint: enable=protected-access
  cache["encoder_output"] = encoder_output
  cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  if beam_size > 1:  # Beam Search
    target_modality = (
        self._hparams.problems[self._problem_idx].target_modality)
    vocab_size = target_modality.top_dimensionality
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    decoded_ids, scores = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        stop_early=(top_beams == 1))

    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]
  else:  # Greedy

    def inner_loop(i, next_id, decoded_ids, cache):
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = tf.expand_dims(
          common_layers.sample_with_temperature(logits, temperature), axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, next_id, decoded_ids, cache

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    scores = None
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    _, _, decoded_ids, _ = tf.while_loop(
        # TODO(llion): Early stopping.
        lambda i, *_: tf.less(i, decode_length),
        inner_loop,
        [tf.constant(0), next_id, decoded_ids, cache],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(lambda t: tf.TensorShape(t.shape), cache),
        ])

  return decoded_ids, scores
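# --- Illustrative sketch (not from the original source). During incremental
# decoding above, step i consumes the single bias row
# decoder_self_attention_bias[:, :, i:i+1, :i+1]: a [1, 1, 1, i+1] slice
# whose entries are all zero, since positions 0..i lie on or below the
# diagonal; the cached "k"/"v" tensors start with length 0 and grow by one
# position per step. A NumPy check of the slice:
import numpy as np

length = 5
bias = np.triu(np.full((length, length), -1e9), k=1)[None, None]
i = 2
step_bias = bias[:, :, i:i + 1, :i + 1]   # shape (1, 1, 1, 3)
assert (step_bias == 0).all()             # step i may attend to steps 0..i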
def _bias(x):
  return common_attention.attention_bias_lower_triangle(tf.shape(x)[1])
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1,
                 alpha=1.0):
  """Fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    features: a map of string to model features.
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha, the
      stronger the preference for longer translations.

  Returns:
    samples: an integer `Tensor`. Top samples from the beam search

  Raises:
    NotImplementedError: If there are multiple data shards.
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams

  inputs = features["inputs"]
  batch_size = common_layers.shape_list(inputs)[0]
  target_modality = self._problem_hparams.target_modality
  if target_modality.is_class_modality:
    decode_length = 1
  else:
    decode_length = common_layers.shape_list(inputs)[1] + decode_length

  # TODO(llion): Clean up this reshaping logic.
  inputs = tf.expand_dims(inputs, axis=1)
  if len(inputs.shape) < 5:
    inputs = tf.expand_dims(inputs, axis=4)
  s = common_layers.shape_list(inputs)
  inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
  # _shard_features called to ensure that the variable names match
  inputs = self._shard_features({"inputs": inputs})["inputs"]
  input_modality = self._problem_hparams.input_modality["inputs"]
  with tf.variable_scope(input_modality.name):
    inputs = input_modality.bottom_sharded(inputs, dp)
  with tf.variable_scope("body"):
    encoder_output, encoder_decoder_attention_bias = dp(
        self.encode,
        inputs,
        features["target_space_id"],
        hparams,
        features=features)
  encoder_output = encoder_output[0]
  encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

  if hparams.pos == "timing":
    timing_signal = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: inputs ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features({"targets": targets})["targets"]
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom_sharded(targets, dp)[0]
    targets = common_layers.flatten4d3d(targets)

    # TODO(llion): Explain! Is this even needed?
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if hparams.pos == "timing":
      targets += timing_signal[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode, targets, cache["encoder_output"],
          cache["encoder_decoder_attention_bias"], bias, hparams, cache,
          nonpadding=transformer._features_to_nonpadding(features, "targets"))

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top_sharded(body_outputs, None, dp)[0]

    return tf.squeeze(logits, axis=[1, 2, 3]), cache

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

  cache = {
      "layer_%d" % layer: {
          "k": tf.zeros([batch_size, 0, key_channels]),
          "v": tf.zeros([batch_size, 0, value_channels]),
      } for layer in range(num_layers)
  }

  # Set 2nd dim to None since it's not invariant in the tf.while_loop
  # Note: Tensor.set_shape() does not work here since it merges shape info.
  # TODO(llion); Find a more robust solution.
  # pylint: disable=protected-access
  if not context.in_eager_mode():
    for layer in cache:
      cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
      cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
  # pylint: enable=protected-access
  cache["encoder_output"] = encoder_output
  cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  if beam_size > 1:  # Beam Search
    target_modality = (
        self._hparams.problems[self._problem_idx].target_modality)
    vocab_size = target_modality.top_dimensionality
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    decoded_ids, scores = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        stop_early=(top_beams == 1))
    decoded_ids = decoded_ids[:, :, 1:]

    # Do roulette wheel selection or inverse roulette wheel selection.
    if self._hparams.roulette == "Normal" or self._hparams.roulette == "Inverse":
      if self._hparams.roulette == "Normal":
        probabilities = tf.pow(tf.constant(2.0), scores)
        start = 0
      else:
        probabilities = tf.subtract(
            tf.constant(1.0), tf.pow(tf.constant(2.0), scores))
        start = beam_size - self._hparams.roulette_beam_size

      summ = tf.reduce_sum(probabilities)
      ex_probs = tf.divide(probabilities, summ)
      # ex_probs = tf.nn.softmax(probabilities)

      # Sample a number between 0 and 1.
      wheel = tf.random_uniform([1])
      upper_bound = tf.constant(0.0)

      # Change this as well if using inverse.
      for i in range(start, self._hparams.roulette_beam_size):
        upper_bound = tf.add(ex_probs[:, i], upper_bound)
        truthValue = tf.squeeze(
            tf.logical_and(wheel >= upper_bound - ex_probs[:, i],
                           wheel <= upper_bound))
        decoded_ids, scores, i = tf.cond(
            truthValue,
            lambda: (decoded_ids[:, i, :], scores[:, i], beam_size),
            lambda: (decoded_ids, scores, i))
  else:  # Greedy

    def inner_loop(i, next_id, decoded_ids, cache):
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = tf.expand_dims(
          common_layers.sample_with_temperature(logits, temperature), axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, next_id, decoded_ids, cache

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    scores = None
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    _, _, decoded_ids, _ = tf.while_loop(
        # TODO(llion): Early stopping.
        lambda i, *_: tf.less(i, decode_length),
        inner_loop,
        [tf.constant(0), next_id, decoded_ids, cache],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(lambda t: tf.TensorShape(t.shape), cache),
        ])

  return decoded_ids, scores
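# --- Illustrative sketch (not from the original source). The roulette-wheel
# block above samples one beam with weight proportional to 2**score (this
# fork treats beam scores as base-2 log probabilities; that convention is an
# assumption carried over from its tf.pow(2.0, scores) call). A NumPy
# equivalent of "Normal" mode for a single example, with a hypothetical
# helper name:
import numpy as np

def roulette_pick(scores, rng=np.random):
  probs = np.power(2.0, scores)   # unnormalized weights
  probs = probs / probs.sum()     # normalize to a distribution
  wheel = rng.uniform()           # spin the wheel once
  return int(np.searchsorted(np.cumsum(probs), wheel))

beam_scores = np.array([-0.5, -1.0, -2.0])
picked_beam = roulette_pick(beam_scores)  # index of the sampled beam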
def decode_inputs_to_outputs(self, decoder_embed_inputs, encoder_outputs,
                             encoder_attn_bias, rule_id_input_placeholder,
                             mem_contexts, mem_outputs, global_step, score,
                             obj_tensors=None):
  if self.hparams.pos == 'timing':
    decoder_embed_inputs = common_attention.add_timing_signal_1d(
        decoder_embed_inputs)
    print('Use positional encoding in decoder text.')
  decoder_embed_inputs = self.update_decoder_embedding(
      decoder_embed_inputs, score, self.model_config.beam_search_size)
  decoder_attn_bias = common_attention.attention_bias_lower_triangle(
      tf.shape(decoder_embed_inputs)[1])
  decoder_embed_inputs = tf.nn.dropout(
      decoder_embed_inputs, 1.0 - self.hparams.layer_prepostprocess_dropout)

  if 'direct' in self.model_config.memory:
    assert 'direct_bert_output' in obj_tensors
    decoder_output = transformer.transformer_multi_decoder(
        decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
        encoder_attn_bias, obj_tensors['direct_bert_output'],
        obj_tensors['direct_bert_bias'], self.hparams,
        save_weights_to=obj_tensors,
        direct_mode=self.model_config.direct_mode)
    if self.model_config.npad_mode == 'static_seq':
      decoder_output = tf.nn.conv1d(
          decoder_output, obj_tensors['npad_w'], 1, 'SAME')
    return decoder_output, decoder_output, None
  elif 'rule' in self.model_config.memory:
    decoder_output, contexts = transformer.transformer_decoder_contexts(
        decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
        encoder_attn_bias, self.hparams)
    cur_context = contexts[0]  # Alternatively: tf.concat(contexts, axis=-1)
    cur_mem_contexts = tf.stack(
        self.embedding_fn(rule_id_input_placeholder, mem_contexts), axis=1)
    cur_mem_outputs = tf.stack(
        self.embedding_fn(rule_id_input_placeholder, mem_outputs), axis=1)
    cur_mem_contexts = tf.reshape(
        cur_mem_contexts,
        [self.model_config.batch_size,
         self.model_config.max_target_rule_sublen
         * self.model_config.max_cand_rules,
         self.model_config.dimension])
    cur_mem_outputs = tf.reshape(
        cur_mem_outputs,
        [self.model_config.batch_size,
         self.model_config.max_target_rule_sublen
         * self.model_config.max_cand_rules,
         self.model_config.dimension])
    # An earlier variant masked empty rule slots (id 0) with a -1e9 bias
    # before the softmax; the current code attends over all memory slots.
    weights = tf.nn.softmax(
        tf.matmul(cur_context, cur_mem_contexts, transpose_b=True))
    mem_output = tf.matmul(weights, cur_mem_outputs)
    temp_output = tf.concat((decoder_output, mem_output), axis=-1)
    # An earlier variant alternated updates between a trainable and a frozen
    # copy of w_ffn/b_ffn via tf.cond on the global step; only the trainable
    # copy remains.
    w_t = tf.get_variable(
        'w_ffn',
        shape=(1, self.model_config.dimension * 2,
               self.model_config.dimension),
        trainable=True)
    b_t = tf.get_variable(
        'b_ffn', shape=(1, 1, self.model_config.dimension), trainable=True)
    mem_output = tf.nn.conv1d(temp_output, w_t, 1, 'SAME') + b_t
    g = tf.greater(
        global_step,
        tf.constant(self.model_config.memory_prepare_step, dtype=tf.int64))
    final_output = tf.cond(g, lambda: mem_output, lambda: decoder_output)
    return final_output, decoder_output, cur_context
  else:
    if self.model_config.architecture == 'ut2t':
      (decoder_output, decoder_extra_output) = (
          universal_transformer_util.universal_transformer_decoder(
              decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
              encoder_attn_bias, self.hparams, save_weights_to=obj_tensors))
      dec_ponder_times, dec_remainders = decoder_extra_output
      extra_dec_loss = (
          self.hparams.act_loss_weight *
          tf.reduce_mean(dec_ponder_times + dec_remainders))
      if self.is_train:
        obj_tensors['extra_decoder_loss'] = extra_dec_loss
    else:
      decoder_output = transformer.transformer_decoder(
          decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
          encoder_attn_bias, self.hparams, save_weights_to=obj_tensors,
          npad_mode=self.model_config.npad_mode)
    final_output = decoder_output
    return final_output, decoder_output, None
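# Illustrative sketch (assumed shapes, not part of the codebase): the rule
# memory lookup above is plain dot-product attention. The decoder context
# scores every stored memory context, and the softmax weights mix the stored
# memory outputs into one vector per target position.
import numpy as np

def _rule_memory_lookup_sketch(cur_context, mem_contexts, mem_outputs):
  """cur_context: [batch, length, dim]; mem_*: [batch, num_slots, dim]."""
  scores = cur_context @ mem_contexts.transpose(0, 2, 1)  # [batch, length, num_slots]
  scores = scores - scores.max(axis=-1, keepdims=True)    # stable softmax
  weights = np.exp(scores)
  weights = weights / weights.sum(axis=-1, keepdims=True)
  return weights @ mem_outputs                            # [batch, length, dim]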
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1,
                 alpha=1.0):
  """Fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    features: a map of string to model features.
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha,
      the stronger the preference for longer translations.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if beam_size == 1 or
            [batch_size, top_beams, <= decode_length]
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }

  Raises:
    NotImplementedError: If there are multiple data shards.
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams
  target_modality = self._problem_hparams.target_modality

  if self.has_input:
    inputs = features["inputs"]
    if target_modality.is_class_modality:
      decode_length = 1
    else:
      decode_length = common_layers.shape_list(inputs)[1] + decode_length

    # TODO(llion): Clean up this reshaping logic.
    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = common_layers.shape_list(inputs)
    batch_size = s[0]
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.input_modality["inputs"]
    with tf.variable_scope(input_modality.name):
      inputs = input_modality.bottom_sharded(inputs, dp)
    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode, inputs, features["target_space_id"], hparams,
          features=features)
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
    partial_targets = None
  else:
    # The problem has no inputs.
    # In this case, features["inputs"] contains partial targets.
    # We force the outputs to begin with these sequences.
    encoder_output = None
    encoder_decoder_attention_bias = None
    partial_targets = tf.squeeze(tf.to_int64(features["inputs"]), [2, 3])
    partial_targets_length = common_layers.shape_list(partial_targets)[1]
    decode_length += partial_targets_length
    batch_size = tf.shape(partial_targets)[0]

  if hparams.pos == "timing":
    timing_signal = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: input ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features({"targets": targets})["targets"]
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom_sharded(targets, dp)[0]
    targets = common_layers.flatten4d3d(targets)

    # TODO(llion): Explain! Is this even needed?
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if hparams.pos == "timing":
      targets += timing_signal[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode, targets, cache.get("encoder_output"),
          cache.get("encoder_decoder_attention_bias"), bias, hparams, cache,
          nonpadding=features_to_nonpadding(features, "targets"))

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top_sharded(body_outputs, None, dp)[0]

    ret = tf.squeeze(logits, axis=[1, 2, 3])
    if partial_targets is not None:
      # If the position is within the given partial targets, we alter the
      # logits to always return those values.
      # A faster approach would be to process the partial targets in one
      # iteration in order to fill the corresponding parts of the cache.
      # This would require broader changes, though.
      vocab_size = tf.shape(ret)[1]

      def forced_logits():
        return tf.one_hot(
            tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
            -1e9)

      ret = tf.cond(
          tf.less(i, partial_targets_length), forced_logits, lambda: ret)
    return ret, cache

  ret = fast_decode(
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      symbols_to_logits_fn=symbols_to_logits_fn,
      hparams=hparams,
      decode_length=decode_length,
      vocab_size=target_modality.top_dimensionality,
      beam_size=beam_size,
      top_beams=top_beams,
      alpha=alpha,
      batch_size=batch_size)
  if partial_targets is not None:
    ret["outputs"] = ret["outputs"][:, partial_targets_length:]
  return ret
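# Illustrative sketch (not part of the codebase): the partial-target forcing
# inside symbols_to_logits_fn builds a one-hot row per batch element, 0.0 at
# the forced id and -1e9 elsewhere, so both argmax and sampling must emit the
# prefix token at step i.
import numpy as np

def _forced_logits_sketch(partial_targets, i, vocab_size):
  """partial_targets: [batch, prefix_len] int ids; returns [batch, vocab_size]."""
  batch = partial_targets.shape[0]
  logits = np.full((batch, vocab_size), -1e9, dtype=np.float32)
  logits[np.arange(batch), partial_targets[:, i]] = 0.0
  return logits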
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1,
                 alpha=1.0, sentence_cache=None):
  """Fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    features: a map of string to model features.
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha,
      the stronger the preference for longer translations.
    sentence_cache: unused here; fast_decode receives self.sentence_cache
      directly.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if beam_size == 1 or
            [batch_size, top_beams, <= decode_length]
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }

  Raises:
    NotImplementedError: If there are multiple data shards.
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams
  target_modality = self._problem_hparams.target_modality

  if self.has_input:
    inputs = features["inputs"]
    if target_modality.is_class_modality:
      decode_length = 1
    else:
      decode_length = common_layers.shape_list(inputs)[1] + decode_length

    # TODO(llion): Clean up this reshaping logic.
    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = common_layers.shape_list(inputs)
    batch_size = s[0]
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.input_modality["inputs"]
    with tf.variable_scope(input_modality.name):
      inputs = input_modality.bottom_sharded(inputs, dp)
    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode, inputs, features["target_space_id"], hparams,
          features=features)
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
    partial_targets = None
  else:
    # The problem has no inputs.
    # In this case, features["inputs"] contains partial targets.
    # We force the outputs to begin with these sequences.
    encoder_output = None
    encoder_decoder_attention_bias = None
    partial_targets = tf.squeeze(tf.to_int64(features["inputs"]), [2, 3])
    partial_targets_length = common_layers.shape_list(partial_targets)[1]
    decode_length += partial_targets_length
    batch_size = tf.shape(partial_targets)[0]

  if hparams.pos == "timing":
    timing_signal = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: input ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features({"targets": targets})["targets"]
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom_sharded(targets, dp)[0]
    targets = common_layers.flatten4d3d(targets)

    # TODO(llion): Explain! Is this even needed?
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if hparams.pos == "timing":
      targets += timing_signal[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode, targets, cache.get("encoder_output"),
          cache.get("encoder_decoder_attention_bias"), bias, hparams, cache,
          nonpadding=features_to_nonpadding(features, "targets"))

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top_sharded(body_outputs, None, dp)[0]

    ret = tf.squeeze(logits, axis=[1, 2, 3])
    if partial_targets is not None:
      # If the position is within the given partial targets, we alter the
      # logits to always return those values.
      # A faster approach would be to process the partial targets in one
      # iteration in order to fill the corresponding parts of the cache.
      # This would require broader changes, though.
      vocab_size = tf.shape(ret)[1]

      def forced_logits():
        return tf.one_hot(
            tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
            -1e9)

      ret = tf.cond(
          tf.less(i, partial_targets_length), forced_logits, lambda: ret)
    return ret, cache, body_outputs

  ret = fast_decode(
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      symbols_to_logits_fn=symbols_to_logits_fn,
      hparams=hparams,
      decode_length=decode_length,
      vocab_size=target_modality.top_dimensionality,
      beam_size=beam_size,
      top_beams=top_beams,
      alpha=alpha,
      batch_size=batch_size,
      sentence_cache=self.sentence_cache,
      cache_flag=self.cache_flag)
  if partial_targets is not None:
    ret["outputs"] = ret["outputs"][:, partial_targets_length:]
  return ret
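# Illustrative sketch (not part of the codebase): the slice
# decoder_self_attention_bias[:, :, i:i + 1, :i + 1] used above keeps only the
# row of the lower-triangular bias for the current step, so the new token can
# attend to itself and all previously decoded positions but nothing beyond.
import numpy as np

def _lower_triangle_bias_sketch(length):
  """[1, 1, length, length] bias: 0 where attention is allowed, -1e9 above the diagonal."""
  mask = np.triu(np.ones((length, length), dtype=np.float32), k=1)
  return (-1e9 * mask)[None, None, :, :]

_bias_full = _lower_triangle_bias_sketch(8)
_step_bias = _bias_full[:, :, 3:4, :4]  # step i=3: shape [1, 1, 1, 4], all zeros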
def test_calculate_branching_model_parameters_transformer(
    self, get_config, expected_hidden_depths):
  tf.reset_default_graph()

  (num_cells, left_inputs, left_layers, left_output_dims, right_inputs,
   right_layers, right_output_dims, combiner_functions,
   final_combiner_function, dummy_activations, dummy_norms, layer_registry,
   is_decoder) = get_config()

  # Get predicted number of parameters.
  (predicted_num_params, output_size, hidden_depths,
   _) = translation_nas_net.calculate_branching_model_parameters(
       encoding_depth=_EMBEDDING_DEPTH,
       left_inputs=left_inputs,
       left_layers=left_layers,
       left_output_dims=left_output_dims,
       right_inputs=right_inputs,
       right_layers=right_layers,
       right_output_dims=right_output_dims,
       combiner_functions=combiner_functions,
       final_combiner_function=final_combiner_function,
       layer_registry=layer_registry,
       num_cells=num_cells,
       encoder_depth=_EMBEDDING_DEPTH)

  # Create model graph.
  input_tensor = tf.zeros([32, _INPUT_LENGTH, _EMBEDDING_DEPTH])
  hparams = transformer.transformer_small()
  if is_decoder:
    nonpadding = None
    mask_future = True
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(_INPUT_LENGTH))
    encoder_cell_outputs = [input_tensor] * 6
  else:
    nonpadding = tf.ones([32, _INPUT_LENGTH])
    mask_future = False
    decoder_self_attention_bias = None
    encoder_cell_outputs = None
  translation_nas_net.apply_nas_layers(
      input_tensor=input_tensor,
      left_inputs=left_inputs,
      left_layers=left_layers,
      left_activations=dummy_activations,
      left_output_dims=left_output_dims,
      left_norms=dummy_norms,
      right_inputs=right_inputs,
      right_layers=right_layers,
      right_activations=dummy_activations,
      right_output_dims=right_output_dims,
      right_norms=dummy_norms,
      combiner_functions=combiner_functions,
      final_combiner_function=final_combiner_function,
      num_cells=num_cells,
      nonpadding=nonpadding,
      layer_registry=layer_registry,
      mask_future=mask_future,
      hparams=hparams,
      var_scope="test",
      encoder_decoder_attention_bias=None,
      encoder_cell_outputs=encoder_cell_outputs,
      decoder_self_attention_bias=decoder_self_attention_bias,
      final_layer_norm=False)

  # Count graph variables.
  trainable_variables_list = tf.trainable_variables()
  empirical_num_params = 0
  for variable_tensor in trainable_variables_list:
    empirical_num_params += _list_product(variable_tensor.shape.as_list())

  # Compare.
  self.assertEqual(empirical_num_params, predicted_num_params)
  self.assertEqual(output_size, _EMBEDDING_DEPTH)
  self.assertEqual(hidden_depths, expected_hidden_depths)
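# Illustrative sketch (not part of the codebase): the empirical count in the
# test above multiplies out every trainable variable's static shape and sums,
# which is what _list_product over shape.as_list() amounts to.
import functools
import operator

def _count_params_sketch(shapes):
  """shapes: iterable of shape lists, e.g. [[512, 512], [512]]."""
  return sum(functools.reduce(operator.mul, s, 1) for s in shapes)

assert _count_params_sketch([[512, 512], [512]]) == 512 * 512 + 512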
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1,
                 alpha=1.0):
  # Single-shard variant: no data parallelism or feature sharding, and
  # additional "context" features are embedded and passed to the encoder
  # alongside the inputs.
  hparams = self._hparams
  target_modality = self._problem_hparams.modality["targets"]

  inputs = features["inputs"]
  decode_length = (common_layers.shape_list(inputs)[1] +
                   features.get("decode_length", decode_length))
  s = common_layers.shape_list(inputs)
  batch_size = s[0]

  input_modality = self._problem_hparams.modality["inputs"]
  context_modality = {}
  contexts = {}
  for feature_name in features:
    if 'context' in feature_name and 'raw' not in feature_name:
      contexts[feature_name] = features[feature_name]
  for context_name in contexts:
    if context_name in self._problem_hparams.modality:
      context_modality[context_name] = (
          self._problem_hparams.modality[context_name])
    else:
      context_modality[context_name] = input_modality

  with tf.variable_scope(input_modality.name, reuse=tf.AUTO_REUSE):
    inputs = input_modality.bottom(inputs)
    for context_name in contexts:
      contexts[context_name] = context_modality[context_name].bottom(
          contexts[context_name])

  with tf.variable_scope("body", reuse=tf.AUTO_REUSE):
    encoder_output, encoder_decoder_attention_bias = self.encode(
        inputs, contexts, features["target_space_id"], hparams,
        features=features)
  partial_targets = None

  if hparams.pos == "timing":
    positional_encoding = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)
  elif hparams.pos == "emb":
    positional_encoding = common_attention.add_positional_embedding(
        tf.zeros([1, decode_length + 1, hparams.hidden_size]),
        hparams.max_length, "targets_positional_embedding", None)
  else:
    positional_encoding = None

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: input ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom(targets)
    targets = common_layers.flatten4d3d(targets)

    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if positional_encoding is not None:
      targets += positional_encoding[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = self.decode(
          targets, cache.get("encoder_output"),
          cache.get("encoder_decoder_attention_bias"), bias, hparams, cache,
          nonpadding=features_to_nonpadding(features, "targets"))

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top(body_outputs, None)

    ret = tf.squeeze(logits, axis=[1, 2, 3])
    return ret, cache

  ret = fast_decode(
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      symbols_to_logits_fn=symbols_to_logits_fn,
      hparams=hparams,
      decode_length=decode_length,
      vocab_size=target_modality.top_dimensionality,
      beam_size=beam_size,
      top_beams=top_beams,
      alpha=alpha,
      batch_size=batch_size,
      force_decode_length=self._decode_hparams.force_decode_length)
  return ret
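# Illustrative sketch (not part of the codebase): a NumPy approximation of the
# sinusoidal signal that common_attention.get_timing_signal_1d produces for
# the "timing" branch above, with the sin and cos halves concatenated along
# the channel axis. The library version also exposes min/max timescale
# arguments; channels is assumed even here.
import numpy as np

def _timing_signal_1d_sketch(length, channels, max_timescale=1.0e4):
  """Returns a [1, length, channels] array of position signals."""
  position = np.arange(length, dtype=np.float32)[:, None]
  num_timescales = channels // 2
  inv_timescales = np.exp(
      -np.log(max_timescale) * np.arange(num_timescales, dtype=np.float32)
      / max(num_timescales - 1, 1))
  scaled_time = position * inv_timescales[None, :]
  signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
  return signal[None, :, :]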
def _bias(x):
  return common_attention.attention_bias_lower_triangle(
      common_layers.shape_list(x)[1])