def transformer_text_encoder(inputs, target_space, hparams, name=None):
  """Transformer text encoder over inputs with unmasked full attention.

  Args:
    inputs: Tensor of shape [batch, length, 1, hparams.hidden_size].
    target_space: int. Used for encoding inputs under a target space id.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    encoder_output: Tensor of shape [batch, length, hparams.hidden_size].
    ed: Tensor of shape [batch, 1, 1, length]. Encoder-decoder attention bias
      for any padded tokens.
  """
  with tf.variable_scope(name, default_name="transformer_text_encoder"):
    inputs = common_layers.flatten4d3d(inputs)
    [
        encoder_input,
        encoder_self_attention_bias,
        ed,
    ] = transformer.transformer_prepare_encoder(
        inputs, target_space=target_space, hparams=hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    encoder_output = transformer.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams)
    return encoder_output, ed
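# Usage sketch (illustrative, not from the original source): runs the encoder
# above under tensor2tensor's `transformer_base()` hparams. The batch/length
# values and the zero-embedding input are assumptions for demonstration only.
#
# hparams = transformer.transformer_base()
# token_embeddings = tf.zeros([8, 20, 1, hparams.hidden_size])  # [B, L, 1, H]
# encoder_output, enc_dec_bias = transformer_text_encoder(
#     token_embeddings, target_space=1, hparams=hparams)
# # encoder_output: [8, 20, hidden_size]; enc_dec_bias: [8, 1, 1, 20].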
def transformer_encoder_ht(inputs, target_space, hparams, features=None,
                           losses=None):
  encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
      transformer.transformer_prepare_encoder(
          inputs, target_space, hparams, features=features))
  # encoder_input = tf.nn.dropout(encoder_input,
  #                               1.0 - hparams.layer_prepostprocess_dropout)
  encoder_output = transformer.transformer_encoder(
      encoder_input,
      self_attention_bias,
      hparams,
      # nonpadding=transformer.features_to_nonpadding(features, "inputs"),
      nonpadding=None,
      save_weights_to=None,
      losses=losses)
  # encoder_output = tf.expand_dims(encoder_output, 2)
  return encoder_output
def transformer_text_encoder(x, space_id, hparams,
                             name="transformer_text_encoder"):
  """Transformer text encoder over inputs with unmasked full attention.

  Args:
    x: Tensor of shape [batch, length, 1, hparams.hidden_size].
    space_id: int, input space id.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    encoder_output: Tensor of shape [batch, length, hparams.hidden_size].
    ed: Tensor of shape [batch, 1, 1, length]. Encoder-decoder attention bias
      for any padded tokens.
  """
  with tf.variable_scope(name):
    x = common_layers.flatten4d3d(x)
    (encoder_input, encoder_self_attention_bias,
     ed) = transformer.transformer_prepare_encoder(x, space_id, hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    encoder_output = transformer.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams)
    return encoder_output, ed
def encode(self, encoder_input, target_space, hparams):
    dir_path = os.path.dirname(os.path.realpath(__file__))
    config_file = os.path.join(dir_path, "config.yml")
    # Use safe_load and a context manager so the file handle is closed and
    # arbitrary YAML tags are not executed.
    with open(config_file) as f:
        config = yaml.safe_load(f)
    enc_name = config["model_params"].split('_')[0][3:]
    if enc_name == "simple":
        (encoder_input, encoder_self_attention_bias,
         encoder_decoder_attention_bias) = (
             transformer.transformer_prepare_encoder(
                 encoder_input, target_space, hparams))
        encoder_input = tf.nn.dropout(
            encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
        encoder_output = transformer.transformer_encoder(
            encoder_input, encoder_self_attention_bias, hparams)
    else:
        (encoder_input, encoder_self_attention_bias_slices,
         encoder_decoder_attention_bias_slices) = (
             parallel_transformer_prepare_encoder(
                 encoder_input, target_space, hparams))
        encoder_input = tf.nn.dropout(
            encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
        encoder_output = getattr(encode_fn, enc_name)(
            encoder_input, encoder_self_attention_bias_slices, hparams,
            "encoder")
        encoder_decoder_attention_bias = tf.stack(
            encoder_decoder_attention_bias_slices)
        encoder_decoder_attention_bias = tf.reduce_mean(
            encoder_decoder_attention_bias, 0)
    return encoder_output, encoder_decoder_attention_bias
def encode(self, stories, questions, target_space, hparams,
           unused_features=None):
  """Encode transformer inputs.

  Args:
    stories: Story tensor [batch_size, story_length, hidden_dim].
    questions: Question tensor [batch_size, question_length, hidden_dim],
      concatenated to the stories along the length axis.
    target_space: scalar, target space ID.
    hparams: hyperparameters for the model.
    unused_features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_output: Encoder representation.
      [batch_size, input_length, hidden_dim]
  """
  inputs = tf.concat([stories, questions], axis=1)
  # inputs = common_layers.flatten4d3d(inputs)
  (encoder_input, encoder_self_attention_bias, _) = (
      transformer.transformer_prepare_encoder(inputs, target_space, hparams))
  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)
  encoder_output = transformer.transformer_encoder(
      encoder_input,
      encoder_self_attention_bias,
      hparams,
      # nonpadding=features_to_nonpadding(features, "inputs"),
      save_weights_to=self.attention_weights)
  return encoder_output
def transformer_text_encoder(x, space_id, hparams,
                             name="transformer_text_encoder"):
  """Transformer text encoder over inputs with unmasked full attention.

  Args:
    x: Tensor of shape [batch, length, 1, hparams.hidden_size].
    space_id: int, input space id.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    encoder_output: Tensor of shape [batch, length, hparams.hidden_size].
    ed: Tensor of shape [batch, 1, 1, length]. Encoder-decoder attention bias
      for any padded tokens.
  """
  with tf.variable_scope(name):
    x = common_layers.flatten4d3d(x)
    (encoder_input, encoder_self_attention_bias,
     ed) = transformer.transformer_prepare_encoder(x, space_id, hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    return transformer.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams), ed
def encode(x, x_space, hparams, name):
  """Transformer preparations and encoder."""
  with tf.variable_scope(name):
    (encoder_input, encoder_self_attention_bias,
     ed) = transformer.transformer_prepare_encoder(x, x_space, hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    return transformer.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams), ed
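# The `ed` bias returned by `encode` is the encoder-decoder attention bias.
# A minimal sketch of the matching decoder call, assuming the standard
# tensor2tensor API (the helper name `decode_with` is an assumption):
def decode_with(encoder_output, ed, targets, hparams):
  decoder_input, decoder_self_attention_bias = (
      transformer.transformer_prepare_decoder(targets, hparams))
  decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.dropout)
  return transformer.transformer_decoder(
      decoder_input, encoder_output, decoder_self_attention_bias, ed, hparams)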
def encode_syntax_template(self, template_embs, template_bias):
    with tf.variable_scope('syntax_encoder', reuse=tf.AUTO_REUSE):
        # template_mask = tf.cast(
        #     tf.equal(template_ids[:, 0, :], self.data.vocab.pad_id),
        #     tf.float32)
        # template_bias = common_attention.attention_bias_ignore_padding(
        #     template_mask)
        # template_embs = self._embedding_fn(
        #     template_ids, self.shared_tensors['syntax_embedding_table'])
        template_outputs = transformer.transformer_encoder(
            template_embs, template_bias, self.hparams)
    return template_outputs, template_bias
def create_t2t_transformer_encoder(
    x_in: "tf.Tensor",
    mask: "tf.Tensor",
    attention_weights: Dict[Text, "tf.Tensor"],
    hparams: "HParams",
    C2: float,
    is_training: "tf.Tensor",
) -> "tf.Tensor":
  """Create t2t transformer encoder."""
  with tf.variable_scope("transformer", reuse=tf.AUTO_REUSE):
    x = create_tf_fnn(
        x_in,
        [hparams.hidden_size],
        hparams.layer_prepostprocess_dropout,
        C2,
        is_training,
        layer_name_suffix="pre_embed",
        activation=None,
        use_bias=False,
        kernel_initializer=tf.random_normal_initializer(
            0.0, hparams.hidden_size**-0.5),
    )
    if hparams.multiply_embedding_mode == "sqrt_depth":
      x *= hparams.hidden_size**0.5
    x *= tf.expand_dims(mask, -1)
    (
        x,
        self_attention_bias,
        encoder_decoder_attention_bias,
    ) = transformer_prepare_encoder(x, None, hparams)
    x *= tf.expand_dims(mask, -1)
    x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout)
    attn_bias_for_padding = None
    # Otherwise the encoder will just use encoder_self_attention_bias.
    if hparams.unidirectional_encoder:
      attn_bias_for_padding = encoder_decoder_attention_bias
    x = transformer_encoder(
        x,
        self_attention_bias,
        hparams,
        nonpadding=mask,
        save_weights_to=attention_weights,
        attn_bias_for_padding=attn_bias_for_padding,
    )
    x *= tf.expand_dims(mask, -1)
    return tf.nn.dropout(
        tf.nn.relu(x), 1.0 - hparams.layer_prepostprocess_dropout)
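# Usage sketch (illustrative): `mask` is a float tensor of shape
# [batch, length] with 1.0 on real tokens and 0.0 on padding. The shapes,
# the C2 value, and the placeholder names below are assumptions for
# demonstration, not from the original source.
#
# x_in = tf.placeholder(tf.float32, [None, None, 128])   # raw features
# mask = tf.placeholder(tf.float32, [None, None])        # 1.0 = real token
# attention_weights = {}
# encoded = create_t2t_transformer_encoder(
#     x_in, mask, attention_weights, hparams, C2=0.002,
#     is_training=tf.constant(True))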
def encoder(name, hparams, inputs, target_space):
  """Compute encoder outputs and attention bias."""
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    (encoder_input, encoder_self_attention_bias,
     encoder_decoder_attention_bias) = (transformer_prepare_encoder(
         inputs, target_space, hparams))
    encoder_input = tf.nn.dropout(
        encoder_input, rate=hparams.layer_prepostprocess_dropout)
    encoder_output = transformer_encoder(encoder_input,
                                         encoder_self_attention_bias, hparams)
  return encoder_output, encoder_decoder_attention_bias
def transformer_text_encoder(inputs, space_id, hparams,
                             name="transformer_text_enc"):
  """Transformer text encoder."""
  with tf.variable_scope(name):
    x = common_layers.flatten4d3d(inputs)
    (encoder_input, encoder_self_attention_bias,
     ed) = transformer.transformer_prepare_encoder(x, space_id, hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    return transformer.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams), ed
def transformer_encoder(features,
                        hparams,
                        embed_scope=None,
                        embed_token_fn=common_embed.embed_tokens,
                        attention_weights=None):
  """Encodes a screen using Transformer.

  Args:
    features: the feature dict.
    hparams: the hyperparameters.
    embed_scope: the scope for token embedding.
    embed_token_fn: the embed function.
    attention_weights: the attention_weights dict.

  Returns:
    encoder_output: a Tensor of shape
      [batch_size, num_steps, max_object_count, hidden_size]
    object_mask: the object mask returned by `prepare_encoder_input`.
    encoder_attn_bias: a Tensor of shape
      [batch_size, num_steps, max_object_count]
  """
  tf.logging.info("Using Transformer screen encoder")
  # Remove the default positional encoding in Transformer.
  object_embed, object_mask, encoder_attn_bias = prepare_encoder_input(
      features=features,
      hparams=hparams,
      embed_scope=embed_scope,
      embed_token_fn=embed_token_fn)
  with tf.variable_scope("encode_screen", reuse=tf.AUTO_REUSE):
    shape = tf.shape(object_embed)
    with tf.control_dependencies(
        [tf.assert_equal(shape[3], hparams.hidden_size)]):
      object_embed = tf.reshape(
          object_embed, [shape[0] * shape[1], shape[2], hparams.hidden_size])
    encoder_input = tf.nn.dropout(
        object_embed, keep_prob=1.0 - hparams.layer_prepostprocess_dropout)
    self_attention_bias = tf.expand_dims(
        tf.expand_dims(
            tf.reshape(encoder_attn_bias, [shape[0] * shape[1], shape[2]]),
            axis=1),
        axis=1)
    encoder_output = transformer.transformer_encoder(
        encoder_input=encoder_input,
        encoder_self_attention_bias=self_attention_bias,
        hparams=hparams,
        save_weights_to=attention_weights,
        make_image_summary=not common_layers.is_xla_compiled())
    encoder_output = tf.reshape(encoder_output,
                                [shape[0], shape[1], shape[2], shape[3]])
    return encoder_output, object_mask, encoder_attn_bias
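# Shape-flow sketch for the fold/unfold trick above (sizes are illustrative
# assumptions): folding the step axis T into the batch means each screen's
# objects attend only within their own screen.
#
#   object_embed:  [B, T, O, H]  --reshape-->  [B*T, O, H]
#   transformer_encoder over the O axis  -->   [B*T, O, H]
#   reshape back                         -->   [B, T, O, H]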
def te_encode(input_seq, hparams, target_space, features, name):
  input_seq = common_layers.flatten4d3d(input_seq)
  (encoder_input, encoder_self_attention_bias, _) = (
      transformer_prepare_encoder(input_seq, target_space, hparams))
  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)
  encoder_output = transformer_encoder(
      encoder_input,
      encoder_self_attention_bias,
      hparams,
      nonpadding=features_to_nonpadding(features, "input_seq"))
  encoder_output = tf.expand_dims(encoder_output, 2)
  return encoder_output
def encode(self, features, input_key):
  hparams = self._hparams
  inputs = common_layers.flatten4d3d(features[input_key])
  (encoder_input, encoder_self_attention_bias, _) = (
      transformer.transformer_prepare_encoder(inputs, problem.SpaceID.EN_TOK,
                                              hparams))
  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)
  encoder_output = transformer.transformer_encoder(
      encoder_input,
      encoder_self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, input_key))
  encoder_output = tf.reduce_mean(encoder_output, axis=1)
  return encoder_output
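# The mean-pooling above yields one vector per example, which can feed a
# retrieval or classification head. A usage sketch (the feature key and the
# dense head are assumptions, not from the source):
#
# sentence_vec = self.encode(features, "inputs")       # [batch, hidden_size]
# logits = tf.layers.dense(sentence_vec, num_classes)  # hypothetical head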
def sim_encode(inputs, target_space, hparams, features):
  # inputs = tf.Print(inputs, [tf.shape(inputs)], "input", summarize=10)
  inputs = common_layers.flatten4d3d(inputs)
  (encoder_input, encoder_self_attention_bias, _) = (
      transformer.transformer_prepare_encoder(inputs, target_space, hparams))
  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)
  encoder_output = transformer.transformer_encoder(
      encoder_input,
      encoder_self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, "inputs"))
  positional_mean = tf.nn.l2_normalize(tf.reduce_mean(encoder_output, 1), 1)
  # out_norm = tf.norm(positional_mean)
  # positional_mean = tf.Print(positional_mean, [out_norm],
  #     "enc_out: (should be b_size**0.5) ", summarize=10)
  # positional_mean = tf.Print(positional_mean, [tf.shape(positional_mean)],
  #     "enc_out: (should be (b_size, h_size)) ", summarize=10)
  return positional_mean
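# Because sim_encode returns L2-normalized mean-pooled vectors, cosine
# similarity between two encoded batches reduces to a dot product. A minimal
# sketch (the variable names and the reuse of one encoder for both sides are
# assumptions, not from the original source):
#
# with tf.variable_scope("sim", reuse=tf.AUTO_REUSE):
#   vec_a = sim_encode(inputs_a, target_space, hparams, features_a)
#   vec_b = sim_encode(inputs_b, target_space, hparams, features_b)
# cosine_sim = tf.reduce_sum(vec_a * vec_b, axis=-1)  # [batch], in [-1, 1]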
def forward(self, contexts_emb, contexts, abbr_inp_emb, longform_emb=None):
    """
    :param contexts_emb: [batch_size, context_len, emb_dim]
    :param contexts: a list of tensors of words, [batch_size] * context_len
    :param abbr_inp_emb: [batch_size, 1, emb_dim]
    :param longform_emb: [batch_size, longform_len, emb_dim]
    :return: decoder_output: predicted abbr embedding,
        [batch_size, 1, emb_dim]
    """
    saved_weights = {}
    extra_loss = None
    contexts_bias = common_attention.attention_bias_ignore_padding(
        tf.to_float(
            tf.equal(tf.stack(contexts, axis=1),
                     self.voc.encode(constant.PAD))))
    contexts_emb = tf.nn.dropout(
        contexts_emb, 1.0 - self.hparams.layer_prepostprocess_dropout)
    abbr_inp_emb = tf.nn.dropout(
        abbr_inp_emb, 1.0 - self.hparams.layer_prepostprocess_dropout)
    # [batch_size, context_len, emb_dim]
    encoder_output = transformer.transformer_encoder(
        contexts_emb,
        contexts_bias,
        hparams=self.hparams,
        save_weights_to=saved_weights)
    # [batch_size, 1, emb_dim]
    decoder_output = transformer.transformer_decoder(
        abbr_inp_emb,
        encoder_output,
        decoder_self_attention_bias=tf.zeros(
            [self.model_config.batch_size, 1, 1, 1]),
        encoder_decoder_attention_bias=contexts_bias,
        hparams=self.hparams,
        save_weights_to=saved_weights)
    return decoder_output, saved_weights, extra_loss
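# Note on the all-zero decoder self-attention bias above: with a single
# decoder position (abbr_inp_emb has length 1) there is nothing to mask
# causally, so a zero [batch, 1, 1, 1] bias is the "no masking" case. For a
# multi-step target one would instead build the standard causal bias; a
# sketch using the tensor2tensor helper (`target_len` is a hypothetical name):
#
# decoder_self_attention_bias = (
#     common_attention.attention_bias_lower_triangle(target_len))
# # shape [1, 1, target_len, target_len]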
def body(self, features):
  hparams = self._hparams
  inputs = features["inputs"]
  target_space = features["target_space_id"]
  inputs = common_layers.flatten4d3d(inputs)
  (encoder_input, encoder_self_attention_bias, _) = (
      transformer.transformer_prepare_encoder(inputs, target_space, hparams))
  encoder_input = tf.nn.dropout(
      encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
  encoder_output = transformer.transformer_encoder(
      encoder_input,
      encoder_self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, "inputs"))
  encoder_output = encoder_output[:, :1, :]
  encoder_output = tf.expand_dims(encoder_output, 2)
  return encoder_output
def transformer_text_encoder(x, space_id, hparams,
                             name="transformer_text_encoder"):
  """Transformer text encoder over inputs with unmasked full attention.

  Args:
    x: Tensor of shape [batch, length, 1, hidden_dim]; flattened to 3-D
      before encoding.
    space_id: int, input space id.
    hparams: HParams, hyperparameters.
    name: string, variable scope.

  Returns:
    x: Tensor of shape [batch, length, hidden_dim].
    ed: Tensor, bias for padded tokens in the input,
      shape [batch, 1, 1, length].
  """
  with tf.variable_scope(name):
    x = common_layers.flatten4d3d(x)
    (encoder_input, encoder_self_attention_bias,
     ed) = transformer.transformer_prepare_encoder(x, space_id, hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    return transformer.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams), ed
def vae_transformer_internal(inputs, targets, target_space, hparams):
  """VAE Transformer, main step used for training."""
  with tf.variable_scope("vae_transformer"):
    is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
    # Prepare inputs, targets, and k.
    inputs = common_layers.flatten4d3d(inputs)
    targets = common_layers.flatten4d3d(targets)
    k = 2**hparams.num_compress_steps
    _, targets = common_layers.pad_to_same_length(
        inputs, targets, final_length_divisible_by=k)

    # Transformer preparations and encoder.
    (encoder_input, encoder_self_attention_bias,
     encoder_decoder_attention_bias) = transformer.transformer_prepare_encoder(
         inputs, target_space, hparams)
    residual_fn = transformer.get_residual_fn(hparams)
    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.residual_dropout)
    encoder_output = transformer.transformer_encoder(
        encoder_input, residual_fn, encoder_self_attention_bias, hparams)

    def get_decoder_autoregressive():
      """Decoder input for autoregressive computation."""
      (a, b) = transformer.transformer_prepare_decoder(targets, hparams)
      return (a, b, tf.constant(0.0))

    # 10% of the time we compress all-zeros, as will be at decoding start.
    prob_targets = 0.9 if is_training else 1.0
    to_compress = tf.cond(tf.less(tf.random_uniform([]), prob_targets),
                          lambda: targets,
                          lambda: tf.zeros_like(targets))
    z, kl_loss = compress_vae(to_compress, hparams, "vae")
    # Decompress.
    for i in range(hparams.num_compress_steps):
      j = hparams.num_hidden_layers - i - 1
      z = decompress(z, hparams, "decompress_%d" % j)

    def get_decoder_from_vae():
      """Decoder input computed by VAE."""
      # Return decoder stuff.
      (a, b) = transformer.transformer_prepare_decoder(
          tf.squeeze(z, axis=2), hparams)
      return (a, b, kl_loss)

    # Randomize decoder inputs.
    prob_do_vae = common_layers.inverse_exp_decay(40000) * 0.7
    step = tf.to_float(tf.contrib.framework.get_global_step())
    if not is_training:
      prob_do_vae = tf.cond(tf.less(step, 40000.0),
                            lambda: tf.constant(0.0),
                            lambda: tf.constant(1.0))
    (decoder_input, decoder_self_attention_bias,
     kl_loss2) = tf.cond(tf.less(tf.random_uniform([]), prob_do_vae),
                         get_decoder_from_vae, get_decoder_autoregressive)

    # Transformer decoder.
    decoder_output = transformer.transformer_decoder(
        decoder_input, encoder_output, residual_fn,
        decoder_self_attention_bias, encoder_decoder_attention_bias, hparams)
    decoder_output = tf.expand_dims(decoder_output, 2)

    cond_self = tf.cond(tf.less(step, 30000.0),
                        lambda: tf.constant(1.0),
                        lambda: tf.constant(0.0))
    prob_self = 0.4 if is_training else cond_self
    (ret, kl_loss) = tf.cond(tf.less(tf.random_uniform([]), prob_self),
                             lambda: (z, kl_loss),
                             lambda: (decoder_output, kl_loss2))
    kl_loss *= common_layers.inverse_exp_decay(50000) * 2.0
    return ret, kl_loss
def build(self, features):
    src_ids = features['src_ids']
    trg_ids = None
    self.batch_size = tf.shape(src_ids)[0]
    if self.is_training:
        trg_ids = features['trg_ids']

    with tf.variable_scope('src_encoder'):
        self.shared_tensors['src_ids'] = src_ids
        src_mask = tf.cast(
            tf.equal(src_ids, self.data.vocab.pad_id), tf.float32)
        src_bias = common_attention.attention_bias_ignore_padding(src_mask)
        self.shared_tensors['src_bias'] = src_bias
        self.shared_tensors['src_mask'] = src_mask

        src_embs = self._embedding_fn(src_ids)
        src_embs = common_attention.add_timing_signal_1d(src_embs)

        if 'syntax_gen' in self.flags.control_mode:
            template_comp_ids = features['template_comp_ids']
            # print_op = tf.print("template_comp_ids output:",
            #                     template_comp_ids)
            # with tf.control_dependencies([print_op]):
            #     template_comp_ids = tf.identity(template_comp_ids)
            template_embs = self._embedding_fn(
                template_comp_ids,
                self.shared_tensors['syntax_embedding_table'])
            template_scale = tf.get_variable(
                'template_scale',
                shape=[1, self.flags.syntax_level, 1, 1],
                trainable=True,
                dtype=tf.float32)
            template_embs *= template_scale
            template_embs = tf.reduce_mean(template_embs, axis=1)
            src_embs += template_embs

        if 'gpt2' in self.flags.model_mode:
            src_outputs = model.gpt2_encoder(
                self.hparams, src_embs, encoder_bias=src_bias)
        elif 't2t' in self.flags.model_mode:
            src_outputs = transformer.transformer_encoder(
                src_embs, src_bias, self.hparams)
        elif 'bert' in self.flags.model_mode:
            bert_model = BertModel(
                config=BertConfig.from_json_file(self.flags.bert_config_file),
                is_training=self.is_training,
                input_ids=src_ids,
                input_mask=1.0 - src_mask,
                embeddings=self.shared_tensors['word_embedding_table'])
            src_outputs = bert_model.get_sequence_output()
        else:
            raise ValueError('model_mode not known.')

        self.shared_tensors['src_outputs'] = src_outputs

        if self.flags.control_mode:
            control_ids = features['control_ids']
            control_mask = tf.cast(
                tf.equal(control_ids, self.data.vocab.pad_id), tf.float32)
            control_bias = common_attention.attention_bias_ignore_padding(
                control_mask)
            control_embs = self._embedding_fn(control_ids)

            if 'gpt2' in self.flags.model_mode:
                control_outputs = model.gpt2_encoder(
                    self.hparams, control_embs, encoder_bias=control_bias)
            elif ('t2t' in self.flags.model_mode
                  or 'bert' in self.flags.model_mode):
                control_outputs = transformer.transformer_encoder(
                    control_embs, control_bias, self.hparams,
                    name='control_encoder')
            else:
                raise ValueError('model_mode not known.')

            self.shared_tensors['control_vec'] = features['control_vec']
            self.shared_tensors['control_outputs'] = control_outputs
            self.shared_tensors['control_bias'] = control_bias
            self.shared_tensors['extra_vec'] = features['extra_vec']

        # if 'syntax_gen' in self.flags.control_mode:
        #     template_comp_ids = features['template_comp_ids']
        #     template_comp_outputs, template_comp_bias = \
        #         self.encode_syntax_template(template_comp_ids)
        #     self.shared_tensors['template_comp_outputs'] = \
        #         template_comp_outputs
        #     self.shared_tensors['template_comp_bias'] = template_comp_bias

    batch_go = tf.tile(
        tf.expand_dims(self._embedding_fn(self.data.vocab.go_id), axis=0),
        [self.batch_size, 1])
    batch_go_id = tf.tile(
        tf.constant(self.data.vocab.go_id, tf.int32, shape=[1]),
        [self.batch_size])
    self.shared_tensors['batch_go'] = batch_go
    self.shared_tensors['batch_go_id'] = batch_go_id

    batch_syntax_go = tf.tile(
        tf.expand_dims(
            self._embedding_fn(self.data.syntax_vocab.go_id), axis=0),
        [self.batch_size, 1])
    batch_syntax_go_id = tf.tile(
        tf.constant(self.data.syntax_vocab.go_id, tf.int32, shape=[1]),
        [self.batch_size])
    self.shared_tensors['batch_syntax_go'] = batch_syntax_go
    self.shared_tensors['batch_syntax_go_id'] = batch_syntax_go_id

    outputs = {}
    outputs['src_ids'] = src_ids
    if self.flags.control_mode:
        outputs["control_vec"] = self.shared_tensors['control_vec']
    # if 'predict' in self.flags.control_mode:
    #     control_vec, outputs = self.classify(
    #         outputs,
    #         self.shared_tensors['control_vec'],
    #         "fix_predict" in self.flags.control_mode)
    #     self.shared_tensors['control_vec'] = control_vec

    if self.flags.control_mode:
        if "flatten" not in self.flags.control_mode:
            # print_op = tf.print("Debug output:",
            #                     self.shared_tensors['control_vec'])
            # with tf.control_dependencies([print_op]):
            #     self.shared_tensors['control_vec'] = tf.identity(
            #         self.shared_tensors['control_vec'])
            duplicate_copies = (
                self.flags.dimension // self.data.control_vec_len)
            batch_size = (self.flags.train_batch_size if self.is_training
                          else self.flags.eval_batch_size)
            control_vec = tf.concat([
                tf.reshape(
                    tf.transpose(
                        tf.tile(
                            tf.expand_dims(
                                self.shared_tensors['control_vec'][o, :],
                                axis=0),
                            [duplicate_copies, 1])),
                    [1, self.flags.dimension]) for o in range(batch_size)
            ], axis=0)
            more_control_vec = tf.zeros(
                [batch_size,
                 self.flags.dimension % self.data.control_vec_len])
            if not self.is_training and self.flags.beam_search_size > 1:
                more_control_vec = tf.zeros(
                    [batch_size * self.flags.beam_search_size,
                     self.flags.dimension % self.data.control_vec_len])
            self.shared_tensors['control_vec'] = tf.concat(
                [control_vec, more_control_vec], axis=1)
        else:
            score = tf.expand_dims(
                self.shared_tensors['control_vec'], axis=-1)
            score = tf.tile(score, [1, 1, self.flags.dimension])
            self.shared_tensors['control_vec'] = score

        if "encoder" in self.flags.control_mode:
            src_outputs = self.update_embedding(src_outputs, False)
            self.shared_tensors['src_outputs'] = src_outputs

    with tf.variable_scope("trg_decoder"):
        if self.is_training:
            # Generate syntax.
            if 'syntax_gen' in self.flags.control_mode:
                syntax_losses = []
                template_simp_ids = features['template_simp_ids']
                template_simp_ids_layers = tf.unstack(
                    template_simp_ids, axis=1)
                for l_id in range(self.flags.syntax_level):
                    template_simp_ids_layer = template_simp_ids_layers[l_id]
                    template_simp_ids_layer_list = tf.unstack(
                        template_simp_ids_layer, axis=1)
                    template_simp_ids_layer_inp_list = (
                        [batch_syntax_go_id]
                        + template_simp_ids_layer_list[:-1])
                    template_simp_emb_list = self._embedding_fn(
                        template_simp_ids_layer_inp_list,
                        self.shared_tensors['syntax_embedding_table'])
                    template_simp_emb = tf.stack(
                        template_simp_emb_list, axis=1)
                    template_mask = tf.cast(
                        tf.equal(template_simp_ids_layers[0],
                                 self.data.vocab.pad_id), tf.float32)
                    template_bias = (
                        common_attention.attention_bias_ignore_padding(
                            template_mask))

                    if l_id == 0:
                        self.shared_tensors[
                            'template_prev_simp_outputs'] = None
                        self.shared_tensors['template_simp_bias'] = None
                    else:
                        template_simp_prev_ids_layers = (
                            template_simp_ids_layers[:l_id])
                        template_simp_prev_ids = tf.stack(
                            template_simp_prev_ids_layers, axis=1)
                        template_simp_prev_embs = self._embedding_fn(
                            template_simp_prev_ids,
                            self.shared_tensors['syntax_embedding_table'])
                        cur_template_scale = template_scale[:, :l_id, :, :]
                        template_simp_prev_embs *= cur_template_scale
                        template_simp_prev_embs = tf.reduce_mean(
                            template_simp_prev_embs, axis=1)
                        template_simp_outputs, template_simp_bias = (
                            self.encode_syntax_template(
                                template_simp_prev_embs, template_bias))
                        self.shared_tensors[
                            'template_prev_simp_outputs'] = (
                                template_simp_outputs)
                        self.shared_tensors[
                            'template_simp_bias'] = template_simp_bias

                    syntax_outputs = self.decode_syntax_template(
                        template_simp_emb)
                    syntax_logits = tf.nn.conv1d(
                        syntax_outputs,
                        tf.expand_dims(
                            self.shared_tensors['proj_syntax_w'], axis=0),
                        1, 'SAME') + tf.expand_dims(
                            tf.expand_dims(
                                self.shared_tensors['proj_syntax_b'],
                                axis=0), axis=0)
                    # syntax_gen = tf.argmax(syntax_logits, axis=-1)
                    syntax_weight = tf.cast(
                        tf.not_equal(template_simp_ids_layer,
                                     self.data.syntax_vocab.pad_id),
                        tf.float32)
                    syntax_loss = sequence_loss(
                        logits=syntax_logits,
                        targets=template_simp_ids_layer,
                        weights=syntax_weight)
                    syntax_losses.append(syntax_loss)

                outputs['loss_syntax'] = tf.add_n(syntax_losses)
                outputs['perplexity_syntax'] = tf.exp(outputs['loss_syntax'])
                tf.summary.scalar("loss_syntax", outputs['loss_syntax'])
                tf.summary.scalar("perplexity_syntax",
                                  outputs['perplexity_syntax'])

                template_simp_prev_ids_layers = template_simp_ids_layers
                template_simp_prev_ids = tf.stack(
                    template_simp_prev_ids_layers, axis=1)
                template_simp_prev_embs = self._embedding_fn(
                    template_simp_prev_ids,
                    self.shared_tensors['syntax_embedding_table'])
                cur_template_scale = template_scale
                template_simp_prev_embs *= cur_template_scale
                template_simp_prev_embs = tf.reduce_mean(
                    template_simp_prev_embs, axis=1)
                template_simp_outputs, template_simp_bias = (
                    self.encode_syntax_template(
                        template_simp_prev_embs, template_bias))
                self.shared_tensors[
                    'template_simp_outputs'] = template_simp_outputs
                self.shared_tensors[
                    'template_simp_bias'] = template_simp_bias

            # Generate sentence.
            trg_ids_list = tf.unstack(trg_ids, axis=1)
            trg_input_ids_list = [batch_go_id] + trg_ids_list[:-1]
            trg_emb_list = self._embedding_fn(trg_input_ids_list)
            trg_input_ids = tf.stack(trg_input_ids_list, axis=1)
            trg_output_ids = tf.stack(trg_ids_list, axis=1)
            trg_emb = tf.stack(trg_emb_list, axis=1)

            decoder_outputs = self.decode_srcs_to_trgs(
                trg_emb=trg_emb, trg_input_ids=trg_input_ids,
                outputs=outputs)
            word_logits = tf.nn.conv1d(
                decoder_outputs,
                tf.expand_dims(self.shared_tensors['proj_word_w'], axis=0),
                1, 'SAME') + tf.expand_dims(
                    tf.expand_dims(
                        self.shared_tensors['proj_word_b'], axis=0), axis=0)
            word_gen = tf.argmax(word_logits, axis=-1)
            outputs['gen'] = word_gen
            outputs['logits'] = word_logits

            weight = tf.cast(
                tf.not_equal(trg_output_ids, self.data.vocab.pad_id),
                tf.float32)
            loss = sequence_loss(
                logits=word_logits, targets=trg_output_ids, weights=weight)
            outputs['loss_decoder'] = loss
            outputs['perplexity_decoder'] = tf.exp(loss)
            tf.summary.scalar("loss_decoder", outputs['loss_decoder'])
            tf.summary.scalar("perplexity_decoder",
                              outputs['perplexity_decoder'])
            # if 'predict' in self.flags.control_mode:
            #     outputs["loss_pred"] = (outputs['loss_length']
            #                             + outputs['loss_syntax']
            #                             + outputs['loss_split'])
            #     tf.summary.scalar("loss_length", outputs['loss_length'])
            #     tf.summary.scalar("loss_syntax", outputs['loss_syntax'])
            #     tf.summary.scalar("loss_split", outputs['loss_split'])
        else:
            outputs['gen_src_syntax_ids'] = features['template_comp_ids']
            confident_scores = []
            self._tile_variables()

            if 'syntax_gen' in self.flags.control_mode:

                def symbol_to_syntax_logits_fn(gen_ids):
                    cur_ids = tf.concat(
                        [tf.expand_dims(batch_syntax_go_id, axis=-1),
                         gen_ids[:, 1:]], axis=1)
                    cur_embs = tf.nn.embedding_lookup(
                        self.shared_tensors['syntax_embedding_table'],
                        cur_ids)
                    cur_outputs = self.decode_syntax_template(cur_embs)
                    cur_logit = tf.matmul(
                        cur_outputs[:, -1, :],
                        self.shared_tensors['proj_syntax_w']
                    ) + self.shared_tensors['proj_syntax_b']
                    return cur_logit

                template_simp_prev_ids_layers = []
                for l_id in range(self.flags.syntax_level):
                    if l_id == 0:
                        self.shared_tensors[
                            'template_prev_simp_outputs'] = None
                        self.shared_tensors['template_simp_bias'] = None
                    else:
                        template_simp_prev_ids = tf.stack(
                            template_simp_prev_ids_layers, axis=1)
                        template_simp_prev_embs = self._embedding_fn(
                            template_simp_prev_ids,
                            self.shared_tensors['syntax_embedding_table'])
                        cur_template_scale = template_scale[:, :l_id, :, :]
                        template_simp_prev_embs *= cur_template_scale
                        template_simp_prev_embs = tf.reduce_mean(
                            template_simp_prev_embs, axis=1)
                        template_mask = tf.cast(
                            tf.equal(template_simp_prev_ids_layers[-1],
                                     self.data.vocab.pad_id), tf.float32)
                        template_bias = (
                            common_attention.attention_bias_ignore_padding(
                                template_mask))
                        template_simp_outputs, template_simp_bias = (
                            self.encode_syntax_template(
                                template_simp_prev_embs, template_bias))
                        self.shared_tensors[
                            'template_prev_simp_outputs'] = (
                                template_simp_outputs)
                        self.shared_tensors[
                            'template_simp_bias'] = template_simp_bias

                    beam_ids, beam_score = beam_search.beam_search(
                        symbols_to_logits_fn=symbol_to_syntax_logits_fn,
                        initial_ids=tf.ones(
                            [self.flags.eval_batch_size], tf.int32
                        ) * self.data.syntax_vocab.go_id,
                        beam_size=self.flags.beam_search_size,
                        decode_length=self.flags.max_syntax_trg_len,
                        vocab_size=self.data.syntax_vocab.size(),
                        alpha=0.6,
                        eos_id=self.data.syntax_vocab.eos_id)
                    top_beam_ids = beam_ids[:, 0, 1:]
                    top_beam_ids = tf.pad(
                        top_beam_ids,
                        [[0, 0],
                         [0, self.flags.max_syntax_trg_len
                          - tf.shape(top_beam_ids)[1]]])
                    confident_score = -beam_score[:, 0] / tf.to_float(
                        tf.shape(top_beam_ids)[1])
                    confident_scores.append(confident_score)
                    template_simp_prev_ids_layers.append(top_beam_ids)

                template_simp_prev_ids = tf.stack(
                    template_simp_prev_ids_layers, axis=1)
                outputs['gen_trg_syntax_ids'] = template_simp_prev_ids
                outputs['gen_trg_syntax_scores'] = tf.add_n(confident_scores)

                template_simp_prev_embs = self._embedding_fn(
                    template_simp_prev_ids,
                    self.shared_tensors['syntax_embedding_table'])
                template_simp_prev_embs *= template_scale
                template_simp_prev_embs = tf.reduce_mean(
                    template_simp_prev_embs, axis=1)
                template_mask = tf.cast(
                    tf.equal(template_simp_prev_ids_layers[-1],
                             self.data.vocab.pad_id), tf.float32)
                template_bias = (
                    common_attention.attention_bias_ignore_padding(
                        template_mask))
                template_simp_outputs, template_simp_bias = (
                    self.encode_syntax_template(
                        template_simp_prev_embs, template_bias))
                self.shared_tensors[
                    'template_simp_outputs'] = template_simp_outputs
                self.shared_tensors[
                    'template_simp_bias'] = template_simp_bias

            def symbol_to_logits_fn(gen_ids):
                cur_ids = tf.concat(
                    [tf.expand_dims(batch_go_id, axis=-1), gen_ids[:, 1:]],
                    axis=1)
                cur_embs = tf.nn.embedding_lookup(
                    self.shared_tensors['word_embedding_table'], cur_ids)
                cur_outputs = self.decode_srcs_to_trgs(
                    trg_emb=cur_embs, trg_input_ids=cur_ids)
                cur_logit = tf.matmul(
                    cur_outputs[:, -1, :],
                    self.shared_tensors['proj_word_w']
                ) + self.shared_tensors['proj_word_b']
                return cur_logit

            beam_ids, beam_score = beam_search.beam_search(
                symbols_to_logits_fn=symbol_to_logits_fn,
                initial_ids=tf.ones(
                    [self.flags.eval_batch_size], tf.int32
                ) * self.data.vocab.go_id,
                beam_size=self.flags.beam_search_size,
                decode_length=self.flags.max_trg_len,
                vocab_size=(self.data.vocab.size()
                            + len(self.data.vocab.more_tokens)),
                alpha=0.6,
                eos_id=self.data.vocab.eos_id)
            top_beam_ids = beam_ids[:, 0, 1:]
            top_beam_ids = tf.pad(
                top_beam_ids,
                [[0, 0],
                 [0, self.flags.max_trg_len - tf.shape(top_beam_ids)[1]]])
            confident_score = -beam_score[:, 0] / tf.to_float(
                tf.shape(top_beam_ids)[1])
            outputs['gen_trg_ids'] = top_beam_ids
            outputs['gen_trg_scores'] = confident_score
            if self.flags.control_mode:
                outputs['control_ids'] = features['control_ids']

    return outputs
def encode_decode_task(features, hparams, train, attention_weights=None):
  """Model core graph for the one-shot action.

  Args:
    features: a dictionary containing "inputs" that is a tensor in shape of
      [batch_size, num_tokens], "verb_id_seq" that is in shape of
      [batch_size, num_actions], "object_spans" and "param_span" tensors in
      shape of [batch_size, num_actions, 2]. 0 is used as padding or
      non-existent values.
    hparams: the general hyperparameters for the model.
    train: the train mode.
    attention_weights: the dict to keep attention weights for analysis.

  Returns:
    loss_dict: the losses for training.
    prediction_dict: the predictions for action tuples.
    areas: the area encodings of the task.
    scope: the embedding scope.
  """
  del train
  input_embeddings, scope = common_embed.embed_tokens(
      features["task"], hparams.task_vocab_size, hparams.hidden_size, hparams)
  with tf.variable_scope("encode_decode", reuse=tf.AUTO_REUSE):
    encoder_nonpadding = tf.minimum(tf.to_float(features["task"]), 1.0)
    input_embeddings = tf.multiply(
        tf.expand_dims(encoder_nonpadding, 2), input_embeddings)
    encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
        transformer.transformer_prepare_encoder(
            input_embeddings, None, hparams, features=None))
    encoder_input = tf.nn.dropout(
        encoder_input, keep_prob=1.0 - hparams.layer_prepostprocess_dropout)
    if hparams.instruction_encoder == "transformer":
      encoder_output = transformer.transformer_encoder(
          encoder_input,
          self_attention_bias,
          hparams,
          save_weights_to=attention_weights,
          make_image_summary=not common_layers.is_xla_compiled())
    else:
      raise ValueError("Unsupported instruction encoder %s" %
                       (hparams.instruction_encoder))
    span_rep = hparams.get("span_rep", "area")
    area_encodings, area_starts, area_ends = area_utils.compute_sum_image(
        encoder_output, max_area_width=hparams.max_span)
    current_shape = tf.shape(area_encodings)
    if span_rep == "area":
      area_encodings, _, _ = area_utils.compute_sum_image(
          encoder_output, max_area_width=hparams.max_span)
    elif span_rep == "basic":
      area_encodings = area_utils.compute_alternative_span_rep(
          encoder_output,
          input_embeddings,
          max_area_width=hparams.max_span,
          hidden_size=hparams.hidden_size,
          advanced=False)
    elif span_rep == "coref":
      area_encodings = area_utils.compute_alternative_span_rep(
          encoder_output,
          input_embeddings,
          max_area_width=hparams.max_span,
          hidden_size=hparams.hidden_size,
          advanced=True)
    else:
      raise ValueError("Unsupported span representation %s" % span_rep)
    areas = {}
    areas["encodings"] = area_encodings
    areas["starts"] = area_starts
    areas["ends"] = area_ends
    with tf.control_dependencies([
        tf.print("encoder_output", tf.shape(encoder_output)),
        tf.assert_equal(current_shape, tf.shape(area_encodings),
                        summarize=100)
    ]):
      paddings = tf.cast(tf.less(self_attention_bias, -1), tf.int32)
    padding_sum, _, _ = area_utils.compute_sum_image(
        tf.expand_dims(tf.squeeze(paddings, [1, 2]), 2),
        max_area_width=hparams.max_span)
    num_areas = common_layers.shape_list(area_encodings)[1]
    area_paddings = tf.reshape(
        tf.minimum(tf.to_float(padding_sum), 1.0), [-1, num_areas])
    areas["bias"] = area_paddings
    decoder_nonpadding = tf.to_float(
        tf.greater(features["verb_refs"][:, :, 1],
                   features["verb_refs"][:, :, 0]))
    if hparams.instruction_encoder == "lstm":
      hparams_decoder = copy.copy(hparams)
      hparams_decoder.set_hparam("pos", "none")
    else:
      hparams_decoder = hparams
    decoder_input, decoder_self_attention_bias = _prepare_decoder_input(
        area_encodings,
        decoder_nonpadding,
        features,
        hparams_decoder,
        embed_scope=scope)
    decoder_input = tf.nn.dropout(
        decoder_input, keep_prob=1.0 - hparams.layer_prepostprocess_dropout)
    if hparams.instruction_decoder == "transformer":
      decoder_output = transformer.transformer_decoder(
          decoder_input=decoder_input,
          encoder_output=encoder_output,
          decoder_self_attention_bias=decoder_self_attention_bias,
          encoder_decoder_attention_bias=encoder_decoder_attention_bias,
          hparams=hparams_decoder)
    else:
      raise ValueError("Unsupported instruction decoder %s" %
                       (hparams.instruction_decoder))
    return decoder_output, decoder_nonpadding, areas, scope
def create_model(self):
    with tf.variable_scope('variables'):
        contexts = []
        for _ in range(self.model_config.max_context_len):
            contexts.append(
                tf.zeros(self.model_config.batch_size, tf.int32,
                         name='context_input'))

        sense_inps, abbr_sinps, abbr_einps = [], [], []
        for _ in range(self.model_config.max_abbrs):
            sense_inps.append(
                tf.zeros(self.model_config.batch_size, tf.int32,
                         name='sense_input'))
            abbr_sinps.append(
                tf.zeros([self.model_config.batch_size], tf.int32,
                         name='sense_sinput'))
            abbr_einps.append(
                tf.zeros([self.model_config.batch_size], tf.int32,
                         name='sense_einput'))
        num_abbr = tf.zeros([self.model_config.batch_size], tf.float32,
                            name='num_abbr')

    with tf.variable_scope('model'):
        contexts_emb = tf.stack(
            self.embedding_fn(contexts, self.embs), axis=1)
        contexts_emb_bias = common_attention.attention_bias_ignore_padding(
            tf.to_float(
                tf.equal(tf.stack(contexts, axis=1),
                         self.data.voc.encode(constant.PAD))))
        contexts_emb = tf.nn.dropout(
            contexts_emb, 1.0 - self.hparams.layer_prepostprocess_dropout)
        encoder_outputs = transformer.transformer_encoder(
            contexts_emb, contexts_emb_bias, self.hparams)

        if self.model_config.aggregate_mode == 'selfattn':
            selfattn_w = tf.get_variable(
                'selfattn_w', [1, self.model_config.dimension, 1], tf.float32,
                initializer=tf.contrib.layers.xavier_initializer())
            selfattn_b = tf.get_variable(
                'selfattn_b', [1, 1, 1], tf.float32,
                initializer=tf.contrib.layers.xavier_initializer())
            weight = tf.nn.tanh(
                tf.nn.conv1d(encoder_outputs, selfattn_w, 1, 'SAME') +
                selfattn_b)
            encoder_outputs *= weight
            aggregate_state = tf.reduce_mean(encoder_outputs, axis=1)
        else:
            aggregate_state = tf.reduce_mean(encoder_outputs, axis=1)

    with tf.variable_scope('pred'):
        proj_w = tf.get_variable(
            'proj_w', [self.model_config.dimension, self.data.sen_cnt],
            tf.float32, initializer=tf.contrib.layers.xavier_initializer())
        proj_b = tf.get_variable(
            'proj_b', [self.data.sen_cnt], tf.float32,
            initializer=tf.contrib.layers.xavier_initializer())

        losses = []
        preds = []
        for abbr_id in range(self.model_config.max_abbrs):
            abbr_sinp = abbr_sinps[abbr_id]
            abbr_einp = abbr_einps[abbr_id]
            sense_inp = sense_inps[abbr_id]

            mask = tf.to_float(
                tf.sequence_mask(abbr_einp, self.data.sen_cnt)) - tf.to_float(
                    tf.sequence_mask(abbr_sinp, self.data.sen_cnt))
            logits = tf.matmul(aggregate_state, proj_w) + proj_b
            logits *= mask
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=sense_inp)
            loss_mask = tf.to_float(tf.not_equal(sense_inp, 0))
            loss *= loss_mask
            losses.append(loss)
            preds.append(tf.nn.top_k(logits, k=5, sorted=True)[1])
        preds = tf.stack(preds, axis=1)

    obj = {
        'contexts': contexts,
        'sense_inp': sense_inps,
        'abbr_sinp': abbr_sinps,
        'abbr_einp': abbr_einps,
        'num_abbr': num_abbr,
        'preds': preds,
    }
    return tf.add_n(losses) / num_abbr, obj
def context_encoder(self, contexts_emb, contexts, abbr_inp_emb=None):
    """
    :param contexts_emb: a tensor of [batch_size, max_context_len, emb_dim]
    :param contexts: a list of [max_context_len, batch_size]
    :param abbr_inp_emb: a tensor of [batch_size, context_len, emb_dim],
        used in transformer_abbr_encoder
    :return: encoder_output: [batch_size, context_len, channel_dim]
        weights: a dict of multihead weights, num_layer elements, each of
            which is [batch_size, num_head, context_len, context_len]
        extra_loss: None unless a universal transformer encoder is used
    """
    weights = {}
    # Create a bias tensor as a mask (big negative values for the padded
    # part); input is [batch_size, context_len], output is
    # [batch_size, 1, 1, context_len].
    contexts_bias = common_attention.attention_bias_ignore_padding(
        tf.to_float(
            tf.equal(tf.stack(contexts, axis=1),
                     self.voc.encode(constant.PAD))))
    # Add dropout to the context input [batch_size, max_context_len, emb_dim].
    contexts_emb = tf.nn.dropout(
        contexts_emb, 1.0 - self.hparams.layer_prepostprocess_dropout)
    # Get the output vector of the transformer,
    # [batch_size, context_len, channel_dim].
    # encoder_output = transformer.transformer_encoder_abbr(
    #     contexts_emb, contexts_bias, abbr_inp_emb,
    #     tf.zeros([self.model_config.batch_size, 1, 1, 1]), self.hparams,
    #     save_weights_to=weights)
    if self.model_config.encoder_mode == 't2t':
        encoder_output = transformer.transformer_encoder(
            contexts_emb, contexts_bias, self.hparams,
            save_weights_to=weights)
        extra_loss = None
    elif self.model_config.encoder_mode == 'ut2t':
        encoder_output, extra_output = (
            universal_transformer_util.universal_transformer_encoder(
                contexts_emb, contexts_bias, self.hparams,
                save_weights_to=weights))
        enc_ponder_times, enc_remainders = extra_output
        extra_loss = (self.hparams.act_loss_weight *
                      tf.reduce_mean(enc_ponder_times + enc_remainders))
    elif self.model_config.encoder_mode == 'abbr_ut2t':
        encoder_output, extra_output = (
            universal_transformer_util.universal_transformer_encoder(
                contexts_emb, contexts_bias, self.hparams,
                save_weights_to=weights))
        enc_ponder_times, enc_remainders = extra_output
        extra_loss = (self.hparams.act_loss_weight *
                      tf.reduce_mean(enc_ponder_times + enc_remainders))
        encoder_output2, extra_output2 = (
            universal_transformer_util.universal_transformer_decoder(
                abbr_inp_emb, encoder_output,
                tf.zeros([self.model_config.batch_size, 1, 1, 1]),
                contexts_bias, self.hparams))
        enc_ponder_times2, enc_remainders2 = extra_output2
        extra_loss2 = (self.hparams.act_loss_weight *
                       tf.reduce_mean(enc_ponder_times2 + enc_remainders2))
        extra_loss += extra_loss2
    else:
        raise ValueError('Unknown encoder_mode.')
    return encoder_output, weights, extra_loss
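# The 'ut2t' branches above return an ACT (adaptive computation time)
# auxiliary loss. A sketch of how that extra loss would typically be folded
# into the training objective (`main_loss` is a hypothetical name, not from
# the original source):
#
# encoder_output, attn_weights, extra_loss = model.context_encoder(
#     contexts_emb, contexts)
# total_loss = main_loss if extra_loss is None else main_loss + extra_loss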
def hierarchical_context_encoder(encoder_input,
                                 encoder_self_attention_bias,
                                 contexts,
                                 context_self_attention_biases,
                                 features,
                                 hparams,
                                 name="discourse_aware_encoder",
                                 save_weights_to=None,
                                 make_image_summary=True,
                                 losses=None):
  input_x = encoder_input
  context_xs = {}
  for context_name in contexts:
    context_xs[context_name] = contexts[context_name]
  context_paddings = {}
  context_nonpaddings = {}
  context_pad_removers = {}

  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    input_padding = common_attention.attention_bias_to_padding(
        encoder_self_attention_bias)
    input_nonpadding = 1.0 - input_padding
    for context_name in context_self_attention_biases:
      context_paddings[
          context_name] = common_attention.attention_bias_to_padding(
              context_self_attention_biases[context_name])
      context_nonpaddings[
          context_name] = 1.0 - context_paddings[context_name]

    input_pad_remover = None
    for context_name in context_paddings:
      context_pad_removers[context_name] = None
    if hparams.use_pad_remover and not common_layers.is_xla_compiled():
      input_pad_remover = expert_utils.PadRemover(input_padding)
      for context_name in context_paddings:
        context_pad_removers[context_name] = expert_utils.PadRemover(
            context_paddings[context_name])

    # Copy hparams except num_hidden_layers -> num_hidden_layers - 1.
    temp_hparam = tf.contrib.training.HParams()
    for key, val in hparams.values().items():
      temp_hparam.add_hparam(key, val)
    temp_hparam.set_hparam("num_hidden_layers", hparams.num_hidden_layers - 1)

    encoder_output = transformer_with_contexts_layers.transformer_encoder(
        input_x,
        encoder_self_attention_bias,
        temp_hparam,
        nonpadding=features_to_nonpadding(features, "inputs"),
        save_weights_to=save_weights_to,
        make_image_summary=make_image_summary)

    context_encoded_outputs = {}
    for context_name in context_xs:
      context_encoded_outputs[
          context_name] = transformer_with_contexts_layers.transformer_encoder(
              context_xs[context_name],
              context_self_attention_biases[context_name],
              temp_hparam,
              nonpadding=features_to_nonpadding(features, context_name),
              save_weights_to=save_weights_to,
              make_image_summary=make_image_summary)

    with tf.variable_scope("hierarchical_context_encoder",
                           reuse=tf.AUTO_REUSE):
      for context_name in context_encoded_outputs:
        # Self-attention feed-forward.
        _y = ffn_self_attention_layer(
            context_encoded_outputs[context_name],
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            save_weights_to=save_weights_to,
            name="attentive_sum")
        # Mean over sequence length.
        context_encoded_outputs[context_name] = tf.reduce_mean(
            _y, axis=1, keep_dims=True)

      encoded_contexts = [
          context_encoded_outputs[context_name]
          for context_name in context_encoded_outputs
      ]
      encoded_contexts = tf.concat(encoded_contexts, axis=1)

      # Copy hparams except num_hidden_layers -> 1.
      temp_hparam = tf.contrib.training.HParams()
      for key, val in hparams.values().items():
        temp_hparam.add_hparam(key, val)
      temp_hparam.set_hparam("num_hidden_layers", 1)

      context_padding = common_attention.embedding_to_padding(
          encoded_contexts)
      ignore_padding = common_attention.attention_bias_ignore_padding(
          context_padding)
      encoded_contexts = transformer_encoder(encoded_contexts, ignore_padding,
                                             temp_hparam)

    with tf.variable_scope("encoder/layer_%d" % hparams.num_hidden_layers,
                           reuse=tf.AUTO_REUSE):
      with tf.variable_scope("context_input_attention"):
        context_padding = common_attention.embedding_to_padding(
            encoded_contexts)
        ignore_padding = common_attention.attention_bias_ignore_padding(
            context_padding)
        _y = common_attention.multihead_attention(
            common_layers.layer_preprocess(encoder_output, hparams),
            encoded_contexts,
            ignore_padding,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            attention_type=hparams.self_attention_type,
            save_weights_to=save_weights_to,
            make_image_summary=make_image_summary,
            max_relative_position=hparams.max_relative_position,
            dropout_broadcast_dims=attention_dropout_broadcast_dims,
            max_length=hparams.get("max_length"),
            vars_3d=hparams.get("attention_variables_3d"))
        encoded_contexts = common_layers.layer_postprocess(
            encoder_output, _y, hparams)

      with tf.variable_scope("input_self_attention"):
        _y = common_attention.multihead_attention(
            common_layers.layer_preprocess(encoder_output, hparams),
            None,
            encoder_self_attention_bias,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            attention_type=hparams.self_attention_type,
            save_weights_to=save_weights_to,
            max_relative_position=hparams.max_relative_position,
            make_image_summary=make_image_summary,
            dropout_broadcast_dims=attention_dropout_broadcast_dims,
            max_length=hparams.get("max_length"),
            vars_3d=hparams.get("attention_variables_3d"))
        encoder_output = common_layers.layer_postprocess(
            encoder_output, _y, hparams)

      with tf.variable_scope("gated_sum"):
        _depth = common_layers.shape_list(encoder_output)[-1]
        gate = tf.layers.dense(
            tf.concat([encoded_contexts, encoder_output], axis=-1),
            _depth,
            activation=tf.nn.sigmoid)
        if save_weights_to:
          save_weights_to["gated_sum"] = gate
        encoder_output = gate * encoder_output + (1. - gate) * encoded_contexts

      with tf.variable_scope("ffn"):
        _y = transformer_ffn_layer(
            common_layers.layer_preprocess(encoder_output, hparams),
            hparams,
            input_pad_remover,
            conv_padding="SAME",
            nonpadding_mask=input_nonpadding,
            losses=losses)
        encoder_output = common_layers.layer_postprocess(
            encoder_output, _y, hparams)

  return common_layers.layer_preprocess(encoder_output, hparams)
def create_model(self):
    with tf.variable_scope('variables'):
        abstr_ph = []
        for _ in range(self.model_config.max_abstr_len):
            abstr_ph.append(
                tf.zeros(self.model_config.batch_size, tf.int32,
                         name='abstract_input'))

        kwords_ph = []
        for _ in range(self.model_config.max_cnt_kword):
            kword = []
            for _ in range(self.model_config.max_kword_len):
                kword.append(
                    tf.zeros(self.model_config.batch_size, tf.int32,
                             name='kword_input'))
            kwords_ph.append(kword)

        # Train for length control.
        if self.is_train:
            kword_occupies_ph = []
            for _ in range(self.model_config.max_cnt_kword):
                kword_occupies_ph.append(
                    tf.zeros(self.model_config.batch_size, tf.float32,
                             name='kword_occupy_input'))

        emb_abstr, emb_kword, proj_w, proj_b = self.get_embedding()
        abstr = tf.stack(self.embedding_fn(abstr_ph, emb_abstr), axis=1)
        kwords = []
        for kword_idx in range(self.model_config.max_cnt_kword):
            kwords.append(self.embedding_fn(kwords_ph[kword_idx], emb_kword))

    with tf.variable_scope('model_encoder'):
        if self.hparams.pos == 'timing':
            abstr = common_attention.add_timing_signal_1d(abstr)
        encoder_embed_inputs = tf.nn.dropout(
            abstr, 1.0 - self.hparams.layer_prepostprocess_dropout)
        abstr_bias = common_attention.attention_bias_ignore_padding(
            tf.to_float(
                tf.equal(tf.stack(abstr_ph, axis=1),
                         self.voc_kword.encode(constant.SYMBOL_PAD))))
        abstr_outputs = transformer.transformer_encoder(
            encoder_embed_inputs, abstr_bias, self.hparams)

    losses = []
    targets = []
    pred_occupies = []
    obj = {}
    hist_vector = None
    if 'kp_attn' in self.model_config.cov_mode:
        hist_vector = tf.zeros(
            [self.model_config.batch_size, 1, self.model_config.dimension])

    with tf.variable_scope('model_decoder'):
        if self.model_config.subword_vocab_size:
            go_id = self.voc_kword.encode(constant.SYMBOL_GO)[0]
        else:
            go_id = self.voc_kword.encode(constant.SYMBOL_GO)
        batch_go = tf.tile(
            tf.expand_dims(self.embedding_fn(go_id, emb_kword), axis=0),
            [self.model_config.batch_size, 1])

        for kword_idx in range(self.model_config.max_cnt_kword):
            if self.is_train:
                kword = kwords[kword_idx][:-1]
                kword_ph = kwords_ph[kword_idx]
                kword_output, kword_output_list = self.decode_step(
                    kword, abstr_outputs, abstr_bias, batch_go,
                    hist_vector=hist_vector)
                kword_logit_list = [
                    self.output_to_logit(o, proj_w, proj_b)
                    for o in kword_output_list
                ]
                kword_target_list = [
                    tf.argmax(o, output_type=tf.int32, axis=-1)
                    for o in kword_logit_list
                ]
                kword_lossbias = [
                    tf.to_float(
                        tf.not_equal(
                            d, self.voc_kword.encode(constant.SYMBOL_PAD)))
                    for d in kword_ph
                ]
                kword_lossbias = tf.stack(kword_lossbias, axis=1)
                if self.model_config.number_samples > 0:
                    loss_fn = tf.nn.sampled_softmax_loss
                else:
                    loss_fn = None
                loss = sequence_loss(
                    logits=tf.stack(kword_logit_list, axis=1),
                    targets=tf.stack(kword_ph, axis=1),
                    weights=kword_lossbias,
                    softmax_loss_function=loss_fn,
                    w=proj_w,
                    b=proj_b,
                    decoder_outputs=tf.stack(kword_output_list, axis=1),
                    number_samples=self.model_config.number_samples)
                kword_target = tf.stack(kword_target_list, axis=1)
                targets.append(kword_target)

                if 'kp_attn' in self.model_config.cov_mode:
                    kword_embed = self.embedding_fn(kword_ph, emb_kword)
                    hist_vector += tf.expand_dims(
                        tf.reduce_mean(tf.stack(kword_embed, axis=1), axis=1),
                        axis=1)

                # Train for length control.
                pred_occupy = self.get_pred_occupy_logit(
                    hist_vector, abstr_outputs)
                occupy_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=pred_occupy, labels=kword_occupies_ph[kword_idx])
                loss += tf.reduce_mean(occupy_loss)
                pred_occupies.append(pred_occupy)

                losses.append(loss)
            else:
                loss, kword_target = self.transformer_beam_search(
                    abstr_outputs, abstr_bias, emb_kword, proj_w, proj_b,
                    hist_vector=hist_vector)
                targets.append(kword_target)
                losses = loss

                if 'kp_attn' in self.model_config.cov_mode:
                    kword_embed = self.embedding_fn(kword_target, emb_kword)
                    hist_vector += tf.expand_dims(
                        tf.reduce_mean(kword_embed, axis=1), axis=1)

                pred_occupy = tf.round(
                    tf.sigmoid(
                        self.get_pred_occupy_logit(hist_vector,
                                                   abstr_outputs)))
                pred_occupies.append(pred_occupy)

            tf.get_variable_scope().reuse_variables()

    if targets:
        obj['targets'] = tf.stack(targets, axis=1)
    obj['abstr_ph'] = abstr_ph
    obj['kwords_ph'] = kwords_ph
    if self.is_train:
        obj['kword_occupies_ph'] = kword_occupies_ph
    pred_occupies = tf.stack(pred_occupies, axis=1)
    obj['pred_occupies'] = pred_occupies
    if isinstance(losses, list):
        losses = tf.add_n(losses)
    return losses, obj
def create_model(self):
    with tf.variable_scope('variables'):
        abstr_ph = []
        for _ in range(self.model_config.max_abstr_len):
            abstr_ph.append(
                tf.zeros(self.model_config.batch_size, tf.int32,
                         name='abstract_input'))

        kwords_ph = []
        for _ in range(self.model_config.max_cnt_kword):
            kword = []
            for _ in range(self.model_config.max_kword_len):
                kword.append(
                    tf.zeros(self.model_config.batch_size, tf.int32,
                             name='kword_input'))
            kwords_ph.append(kword)

        emb_abstr, emb_kword, proj_w, proj_b = self.get_embedding()
        abstr = tf.stack(self.embedding_fn(abstr_ph, emb_abstr), axis=1)
        kwords = []
        for kword_idx in range(self.model_config.max_cnt_kword):
            kwords.append(self.embedding_fn(kwords_ph[kword_idx], emb_kword))

    with tf.variable_scope('model_encoder'):
        if self.hparams.pos == 'timing':
            abstr = common_attention.add_timing_signal_1d(abstr)
        encoder_embed_inputs = tf.nn.dropout(
            abstr, 1.0 - self.hparams.layer_prepostprocess_dropout)
        abstr_bias = common_attention.attention_bias_ignore_padding(
            tf.to_float(
                tf.equal(tf.stack(abstr_ph, axis=1),
                         self.voc_kword.encode(constant.SYMBOL_PAD))))
        abstr_outputs = transformer.transformer_encoder(
            encoder_embed_inputs, abstr_bias, self.hparams)

    if 'tuzhaopeng' in self.model_config.cov_mode:
        # Integer division so the shape entry stays an int under Python 3.
        attn_stick = tf.ones(
            [self.model_config.batch_size, self.model_config.num_heads, 1,
             self.model_config.dimension // self.model_config.num_heads],
            tf.float32, 'attn_memory')

    losses = []
    targets = []
    obj = {}
    with tf.variable_scope('model_decoder'):
        for kword_idx in range(self.model_config.max_cnt_kword):
            if self.is_train:
                kword = kwords[kword_idx][:-1]
                kword_ph = kwords_ph[kword_idx]
                kword_output_list, new_attn_stick = self.decode_step(
                    kword, abstr_outputs, abstr_bias, attn_stick)
                kword_logit_list = [
                    self.output_to_logit(o, proj_w, proj_b)
                    for o in kword_output_list
                ]
                kword_target_list = [
                    tf.argmax(o, output_type=tf.int32, axis=-1)
                    for o in kword_logit_list
                ]
                attn_stick = new_attn_stick
                if self.model_config.number_samples > 0:
                    loss_fn = tf.nn.sampled_softmax_loss
                else:
                    loss_fn = None
                kword_lossbias = [
                    tf.to_float(
                        tf.not_equal(
                            d, self.voc_kword.encode(constant.SYMBOL_PAD)))
                    for d in kword_ph
                ]
                kword_lossbias = tf.stack(kword_lossbias, axis=1)
                loss = sequence_loss(
                    logits=tf.stack(kword_logit_list, axis=1),
                    targets=tf.stack(kword_ph, axis=1),
                    weights=kword_lossbias,
                    softmax_loss_function=loss_fn,
                    w=proj_w,
                    b=proj_b,
                    decoder_outputs=tf.stack(kword_output_list, axis=1),
                    number_samples=self.model_config.number_samples)
                targets.append(tf.stack(kword_target_list, axis=1))

                if ('tuzhaopeng' in self.model_config.cov_mode and
                        'kp_attn' in self.model_config.cov_mode):
                    target_emb = tf.stack(
                        self.embedding_fn(kword_target_list, emb_kword),
                        axis=1)
                    target_emb = common_attention.split_heads(
                        target_emb, self.model_config.num_heads)
                    target_emb = tf.reduce_mean(target_emb, axis=2)
                    target_emb_trans = tf.get_variable(
                        'dim_weight_trans',
                        shape=[1, target_emb.get_shape()[-1].value,
                               target_emb.get_shape()[-1].value],
                        dtype=tf.float32,
                        initializer=tf.contrib.layers.xavier_initializer())
                    target_emb = tf.nn.conv1d(target_emb, target_emb_trans,
                                              1, 'SAME')
                    target_emb = tf.expand_dims(target_emb, axis=2)
                    attn_stick += target_emb
                losses.append(loss)
            else:
                if self.model_config.beam_search_size > 0:
                    loss, target, new_attn_stick = (
                        self.transformer_beam_search(
                            abstr_outputs, abstr_bias, emb_kword, proj_w,
                            proj_b, attn_stick=attn_stick))
                else:
                    loss, target, new_attn_stick = self.greed_search(
                        kword_idx, abstr_outputs, abstr_bias, emb_kword,
                        proj_w, proj_b, attn_stick=attn_stick)
                targets.append(target)
                losses = loss
                attn_stick = new_attn_stick

                if ('tuzhaopeng' in self.model_config.cov_mode and
                        'kp_attn' in self.model_config.cov_mode):
                    target.set_shape([self.model_config.batch_size,
                                      self.model_config.max_kword_len])
                    target_list = tf.unstack(target, axis=1)
                    target_emb = tf.stack(
                        self.embedding_fn(target_list, emb_kword), axis=1)
                    target_emb = common_attention.split_heads(
                        target_emb, self.model_config.num_heads)
                    target_emb = tf.reduce_mean(target_emb, axis=2)
                    target_emb_trans = tf.get_variable(
                        'dim_weight_trans',
                        shape=[1, target_emb.get_shape()[-1].value,
                               target_emb.get_shape()[-1].value],
                        dtype=tf.float32,
                        initializer=tf.contrib.layers.xavier_initializer())
                    target_emb = tf.nn.conv1d(target_emb, target_emb_trans,
                                              1, 'SAME')
                    target_emb = tf.expand_dims(target_emb, axis=2)
                    attn_stick += target_emb
            tf.get_variable_scope().reuse_variables()

    if targets:
        obj['targets'] = tf.stack(targets, axis=1)
    obj['abstr_ph'] = abstr_ph
    obj['kwords_ph'] = kwords_ph
    obj['attn_stick'] = attn_stick
    if isinstance(losses, list):
        losses = tf.add_n(losses)
    return losses, obj
def transformer_fn(self, sentence_complex_input_placeholder, emb_complex,
                   sentence_simple_input_placeholder, emb_simple, w, b,
                   rule_id_input_placeholder, rule_target_input_placeholder,
                   mem_contexts, mem_outputs, global_step, score,
                   comp_features, obj):
    encoder_mask = tf.to_float(
        tf.equal(tf.stack(sentence_complex_input_placeholder, axis=1),
                 self.data.vocab_complex.encode(constant.SYMBOL_PAD)))
    encoder_attn_bias = common_attention.attention_bias_ignore_padding(encoder_mask)
    obj_tensors = {}
    train_mode = self.model_config.train_mode

    if self.model_config.bert_mode:
        # Leave GPU:1 free for the decoder except in modes that need it here.
        gpu_id = (0 if (train_mode == 'static_seq'
                        or train_mode == 'static_self-critical'
                        or 'direct' in self.model_config.memory)
                  else 1)
        with tf.device('/device:GPU:%s' % gpu_id):
            sentence_complex_input = tf.stack(
                sentence_complex_input_placeholder, axis=1)
            bert_model = BertModel(
                BertConfig.from_json_file(self.model_config.bert_config),
                self.is_train, sentence_complex_input,
                input_mask=1.0 - encoder_mask, token_type_ids=None,
                use_one_hot_embeddings=False)
            encoder_embed_inputs = bert_model.embedding_output
            encoder_outputs = bert_model.sequence_output
            # Update emb_complex from the BERT embedding table.
            emb_complex = bert_model.embedding_table
            if (self.model_config.tie_embedding == 'all'
                    or self.model_config.tie_embedding == 'enc_dec'):
                emb_simple = bert_model.embedding_table
            if (self.model_config.tie_embedding == 'all'
                    or self.model_config.tie_embedding == 'dec_out'):
                emb_w_proj = tf.get_variable(
                    'emb_w_proj',
                    shape=[self.model_config.dimension, self.model_config.dimension],
                    initializer=tf.contrib.layers.xavier_initializer(),
                    dtype=tf.float32)
                w = tf.matmul(bert_model.embedding_table, emb_w_proj)

            if 'direct' in self.model_config.memory:
                with tf.device('/device:GPU:1'):
                    direct_mask = tf.to_float(
                        tf.equal(tf.stack(rule_target_input_placeholder, axis=1),
                                 self.data.vocab_complex.encode(constant.SYMBOL_PAD)))
                    direct_bert_model = BertModel(
                        BertConfig.from_json_file(self.model_config.bert_config),
                        self.is_train,
                        tf.stack(rule_target_input_placeholder, axis=1),
                        input_mask=1.0 - direct_mask, token_type_ids=None,
                        use_one_hot_embeddings=False,
                        embedding_table=emb_simple, scope='direct')
                    direct_bert_output = direct_bert_model.sequence_output
                    obj_tensors['direct_bert_bias'] = (
                        common_attention.attention_bias_ignore_padding(direct_mask))
                    obj_tensors['direct_bert_output'] = direct_bert_output
    else:
        encoder_embed_inputs = tf.stack(
            self.embedding_fn(sentence_complex_input_placeholder, emb_complex),
            axis=1)
        if self.hparams.pos == 'timing':
            encoder_embed_inputs = common_attention.add_timing_signal_1d(
                encoder_embed_inputs)
            print('Use positional encoding in encoder text.')
        if self.model_config.subword_vocab_size and self.model_config.seg_mode:
            encoder_embed_inputs = common_attention.add_positional_embedding(
                encoder_embed_inputs, 100, 'seg_embedding',
                positions=obj['line_comp_segids'])
            print('Add segment embedding.')
        with tf.variable_scope('transformer_encoder'):
            encoder_embed_inputs = tf.nn.dropout(
                encoder_embed_inputs,
                1.0 - self.hparams.layer_prepostprocess_dropout)
            if self.model_config.architecture == 'ut2t':
                encoder_outputs, encoder_extra_output = (
                    universal_transformer_util.universal_transformer_encoder(
                        encoder_embed_inputs, encoder_attn_bias, self.hparams))
                enc_ponder_times, enc_remainders = encoder_extra_output
                extra_encoder_loss = (
                    self.hparams.act_loss_weight
                    * tf.reduce_mean(enc_ponder_times + enc_remainders))
                if self.is_train:
                    obj_tensors['extra_encoder_loss'] = extra_encoder_loss
            else:
                encoder_outputs = transformer.transformer_encoder(
                    encoder_embed_inputs, encoder_attn_bias, self.hparams)

    # Update score based on multiplier.
    score, pred_score_tuple = self.update_score(
        score, encoder_outputs=encoder_outputs,
        encoder_mask=tf.to_float(
            tf.not_equal(tf.stack(sentence_complex_input_placeholder, axis=1),
                         self.data.vocab_complex.encode(constant.SYMBOL_PAD))),
        comp_features=comp_features)
    encoder_outputs = self.update_encoder_embedding(encoder_outputs, score)

    encoder_embed_inputs_list = tf.unstack(encoder_embed_inputs, axis=1)
    with tf.variable_scope('transformer_decoder', reuse=tf.AUTO_REUSE):
        if (self.model_config.subword_vocab_size
                or 'bert_token' in self.model_config.bert_mode):
            go_id = self.data.vocab_simple.encode(constant.SYMBOL_GO)[0]
        else:
            go_id = self.data.vocab_simple.encode(constant.SYMBOL_GO)
        batch_go = tf.tile(
            tf.expand_dims(self.embedding_fn(go_id, emb_simple), axis=0),
            [self.model_config.batch_size, 1])

        # For static_seq train_mode.
        if self.model_config.npad_mode == 'static_seq':
            with tf.variable_scope('npad'):
                npad_w = tf.get_variable(
                    'npad_w',
                    shape=[1, self.model_config.dimension, self.model_config.dimension],
                    initializer=tf.contrib.layers.xavier_initializer(),
                    dtype=tf.float32)
                obj_tensors['npad_w'] = npad_w

        if self.is_train and (train_mode == 'teacher'
                              or train_mode == 'teachercritical'
                              or train_mode == 'teachercriticalv2'):
            # General (teacher-forced) training.
            print('Use Generally Process.')
            decoder_embed_inputs_list = self.embedding_fn(
                sentence_simple_input_placeholder[:-1], emb_simple)
            final_output, decoder_output, cur_context = self.decode_step(
                decoder_embed_inputs_list, encoder_outputs, encoder_attn_bias,
                rule_id_input_placeholder, mem_contexts, mem_outputs,
                global_step, score, batch_go, obj_tensors)
            decoder_logit = (
                tf.nn.conv1d(final_output,
                             tf.expand_dims(tf.transpose(w), axis=0), 1, 'SAME')
                + tf.expand_dims(tf.expand_dims(b, axis=0), axis=0))
            decoder_target_list = []
            decoder_logit_list = tf.unstack(decoder_logit, axis=1)
            for logit in decoder_logit_list:
                decoder_target_list.append(
                    tf.argmax(logit, output_type=tf.int32, axis=-1))
            decoder_output_list = [
                tf.squeeze(d, 1) for d in tf.split(
                    decoder_output, self.model_config.max_simple_sentence, axis=1)]
            final_output_list = [
                tf.squeeze(d, 1) for d in tf.split(
                    final_output, self.model_config.max_simple_sentence, axis=1)]
            if self.model_config.pointer_mode:
                segment_mask = None
                if 'line_comp_segids' in obj:
                    segment_mask = obj['line_comp_segids']
                decoder_logit_list = word_distribution(
                    decoder_logit_list, decoder_output_list, encoder_outputs,
                    encoder_embed_inputs, sentence_complex_input_placeholder,
                    obj_tensors, self.model_config, self.data, segment_mask)
        elif self.is_train and (train_mode == 'static_seq'
                                or train_mode == 'static_self-critical'):
            decoder_target_list = []
            decoder_logit_list = []
            decoder_embed_inputs_list = []
            # Overridden on every step of the loop below.
            final_output_list = []
            decoder_output_list = []
            contexts = []
            sample_target_list = []
            sample_logit_list = []
            gpu_assign_interval = int(self.model_config.max_simple_sentence / 3)
            for step in range(self.model_config.max_simple_sentence):
                gpu_id = int(step / gpu_assign_interval)
                if gpu_id > 3:
                    gpu_id = 3
                gpu_id += 1
                with tf.device('/device:GPU:%s' % gpu_id):
                    print('Step%s with GPU%s' % (step, gpu_id))
                    final_outputs, _, cur_context = self.decode_step(
                        decoder_embed_inputs_list, encoder_outputs,
                        encoder_attn_bias, rule_id_input_placeholder,
                        mem_contexts, mem_outputs, global_step, score,
                        batch_go, obj_tensors)
                    final_output_list = [
                        tf.squeeze(d, 1)
                        for d in tf.split(final_outputs, step + 1, axis=1)]
                    final_output = final_output_list[-1]
                    # if self.model_config.npad_mode == 'static_seq':
                    #     final_output = tf.matmul(final_output, npad_w)
                    last_logit_list = self.output_to_logit(final_output, w, b)
                    last_target_list = tf.argmax(
                        last_logit_list, output_type=tf.int32, axis=-1)
                    decoder_logit_list.append(last_logit_list)
                    decoder_target_list.append(last_target_list)
                    # Feed the greedy token back in; stop_gradient keeps the
                    # feedback path out of backprop.
                    decoder_embed_inputs_list.append(
                        tf.stop_gradient(
                            self.embedding_fn(last_target_list, emb_simple)))
                    if train_mode == 'static_self-critical':
                        last_sample_list = tf.multinomial(last_logit_list, 1)
                        sample_target_list.append(last_sample_list)
                        indices = tf.stack(
                            [tf.range(0, self.model_config.batch_size,
                                      dtype=tf.int64),
                             tf.squeeze(last_sample_list)],
                            axis=-1)
                        sample_logit_list.append(
                            tf.gather_nd(tf.nn.softmax(last_logit_list), indices))
        else:
            # Beam search for evaluation/inference.
            print('Use Beam Search with Beam Search Size %d.'
                  % self.model_config.beam_search_size)
            return self.transformer_beam_search(
                encoder_outputs, encoder_attn_bias, encoder_embed_inputs_list,
                sentence_complex_input_placeholder, emb_simple, w, b,
                rule_id_input_placeholder, mem_contexts, mem_outputs,
                global_step, score, obj, obj_tensors)

        gt_target_list = sentence_simple_input_placeholder
        # Note: the original guarded the sample lists on 'dynamic_self-critical'
        # and referenced undefined sampled_* names; this variant only produces
        # samples in 'static_self-critical' mode, so guard on that instead.
        output = ModelOutput(
            contexts=cur_context if 'rule' in self.model_config.memory else None,
            encoder_outputs=encoder_outputs,
            decoder_outputs_list=(final_output_list
                                  if train_mode != 'dynamic_self-critical' else None),
            final_outputs_list=(final_output_list
                                if train_mode != 'dynamic_self-critical' else None),
            decoder_logit_list=(decoder_logit_list
                                if train_mode != 'dynamic_self-critical' else None),
            gt_target_list=gt_target_list,
            encoder_embed_inputs_list=tf.unstack(encoder_embed_inputs, axis=1),
            decoder_target_list=decoder_target_list,
            sample_logit_list=(sample_logit_list
                               if train_mode == 'static_self-critical' else None),
            sample_target_list=(sample_target_list
                                if train_mode == 'static_self-critical' else None),
            pred_score_tuple=(pred_score_tuple
                              if 'pred' in self.model_config.tune_mode else None),
            obj_tensors=obj_tensors,
        )
        return output
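# --- Illustrative sketch (not from the original model) ----------------------
# The 'static_seq' branch above unrolls decoding in Python: each step re-runs
# the decoder on everything emitted so far and feeds back the embedding of the
# argmax token, wrapped in stop_gradient so the feedback path carries no
# gradient. decode_fn and output_to_logit below are hypothetical stand-ins for
# the model's decode_step/output_to_logit, not its real API.
def _unrolled_greedy_decode_sketch(decode_fn, output_to_logit, embedding, max_len):
    import tensorflow as tf
    decoder_inputs = []          # embeddings of tokens emitted so far
    logit_list, target_list = [], []
    for step in range(max_len):
        outputs = decode_fn(decoder_inputs)    # [batch, step + 1, dim]
        last_output = outputs[:, -1, :]        # only the newest position
        logits = output_to_logit(last_output)  # [batch, vocab]
        targets = tf.argmax(logits, output_type=tf.int32, axis=-1)
        logit_list.append(logits)
        target_list.append(targets)
        # Treat the fed-back token embedding as a constant during backprop.
        decoder_inputs.append(tf.stop_gradient(tf.gather(embedding, targets)))
    return logit_list, target_list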
def transformer_fn(self, sentence_complex_input_placeholder, emb_complex,
                   sentence_simple_input_placeholder, emb_simple, w, b,
                   rule_id_input_placeholder, mem_contexts, mem_outputs,
                   global_step):
    encoder_embed_inputs = tf.stack(
        self.embedding_fn(sentence_complex_input_placeholder, emb_complex), axis=1)
    encoder_attn_bias = common_attention.attention_bias_ignore_padding(
        tf.to_float(tf.equal(tf.stack(sentence_complex_input_placeholder, axis=1),
                             self.data.vocab_complex.encode(constant.SYMBOL_PAD))))
    if self.hparams.pos == 'timing':
        encoder_embed_inputs = common_attention.add_timing_signal_1d(
            encoder_embed_inputs)
        print('Use positional encoding in encoder text.')

    with tf.variable_scope('transformer_encoder'):
        encoder_embed_inputs = tf.nn.dropout(
            encoder_embed_inputs, 1.0 - self.hparams.layer_prepostprocess_dropout)
        encoder_outputs = transformer.transformer_encoder(
            encoder_embed_inputs, encoder_attn_bias, self.hparams)

    encoder_embed_inputs_list = tf.unstack(encoder_embed_inputs, axis=1)
    with tf.variable_scope('transformer_decoder'):
        train_mode = self.model_config.train_mode
        if self.is_train and (train_mode == 'teacher'
                              or train_mode == 'teachercritical'
                              or train_mode == 'teachercriticalv2'):
            # General (teacher-forced) training.
            print('Use Generally Process.')
            decoder_embed_inputs_list = self.embedding_fn(
                sentence_simple_input_placeholder[:-1], emb_simple)
            final_output_list, decoder_output_list, cur_context = self.decode_step(
                decoder_embed_inputs_list, encoder_outputs, encoder_attn_bias,
                rule_id_input_placeholder, mem_contexts, mem_outputs, global_step)
            decoder_logit_list = [
                self.output_to_logit(o, w, b) for o in final_output_list]
            decoder_target_list = [
                tf.argmax(o, output_type=tf.int32, axis=-1)
                for o in decoder_logit_list]
        elif self.is_train and train_mode == 'dynamic_self-critical':
            # Dynamic self-critical training: decode step by step inside a
            # tf.while_loop, accumulating greedy targets, sampled targets, and
            # the probabilities of the sampled tokens in TensorArrays.
            decoder_target_tensor = tf.TensorArray(
                tf.int32, size=0, dynamic_size=True, clear_after_read=False,
                element_shape=[self.model_config.batch_size])
            sampled_target_tensor = tf.TensorArray(
                tf.int32, size=0, dynamic_size=True, clear_after_read=False,
                element_shape=[self.model_config.batch_size])
            sampled_logit_tensor = tf.TensorArray(
                tf.float32, size=0, dynamic_size=True, clear_after_read=False,
                element_shape=[self.model_config.batch_size])

            def _is_finished(step, decoder_target_tensor, sampled_target_tensor,
                             sampled_logit_tensor):
                return tf.less(step, self.model_config.max_simple_sentence)

            def _recursive(step, decoder_target_tensor, sampled_target_tensor,
                           sampled_logit_tensor):
                decoder_target_stack = tf.transpose(
                    decoder_target_tensor.stack(), perm=[1, 0])

                def get_empty_emb():
                    # Step 0: only the (zero) GO embedding.
                    return tf.zeros(
                        [self.model_config.batch_size, 1, self.model_config.dimension])

                def get_emb():
                    batch_go = tf.zeros(
                        [self.model_config.batch_size, 1, self.model_config.dimension])
                    return tf.concat(
                        [batch_go, tf.gather(emb_simple, decoder_target_stack)],
                        axis=1)

                decoder_emb_inputs = tf.cond(
                    tf.equal(step, 0), get_empty_emb, get_emb)
                final_outputs, _, _ = self.decode_inputs_to_outputs(
                    decoder_emb_inputs, encoder_outputs, encoder_attn_bias,
                    rule_id_input_placeholder, mem_contexts, mem_outputs,
                    global_step)
                final_output = final_outputs[:, -1, :]
                decoder_logit = tf.add(tf.matmul(final_output, tf.transpose(w)), b)
                decoder_target = tf.stop_gradient(
                    tf.argmax(decoder_logit, output_type=tf.int32, axis=-1))
                sampled_target = tf.cast(
                    tf.squeeze(tf.multinomial(decoder_logit, 1), axis=1), tf.int32)
                indices = tf.stack(
                    [tf.range(0, self.model_config.batch_size, dtype=tf.int32),
                     tf.squeeze(sampled_target)],
                    axis=-1)
                logit_unit = tf.gather_nd(
                    tf.nn.softmax(decoder_logit, axis=1), indices)
                decoder_target_tensor = decoder_target_tensor.write(
                    step, decoder_target)
                sampled_target_tensor = sampled_target_tensor.write(
                    step, sampled_target)
                sampled_logit_tensor = sampled_logit_tensor.write(step, logit_unit)
                return (step + 1, decoder_target_tensor, sampled_target_tensor,
                        sampled_logit_tensor)

            step = tf.constant(0)
            (_, decoder_target_tensor, sampled_target_tensor,
             sampled_logit_tensor) = tf.while_loop(
                 _is_finished, _recursive,
                 [step, decoder_target_tensor, sampled_target_tensor,
                  sampled_logit_tensor],
                 back_prop=True, parallel_iterations=1, swap_memory=False)

            decoder_target_tensor = decoder_target_tensor.stack()
            decoder_target_tensor.set_shape(
                [self.model_config.max_simple_sentence, self.model_config.batch_size])
            decoder_target_tensor = tf.transpose(decoder_target_tensor, perm=[1, 0])
            decoder_target_list = tf.unstack(decoder_target_tensor, axis=1)

            sampled_target_tensor = sampled_target_tensor.stack()
            sampled_target_tensor.set_shape(
                [self.model_config.max_simple_sentence, self.model_config.batch_size])
            sampled_target_tensor = tf.transpose(sampled_target_tensor, perm=[1, 0])
            sampled_target_list = tf.unstack(sampled_target_tensor, axis=1)

            sampled_logit_tensor = sampled_logit_tensor.stack()
            sampled_logit_tensor.set_shape(
                [self.model_config.max_simple_sentence, self.model_config.batch_size])
            sampled_logit_tensor = tf.transpose(sampled_logit_tensor, perm=[1, 0])
            sampled_logit_list = tf.unstack(sampled_logit_tensor, axis=1)
        else:
            # Beam search for evaluation/inference.
            print('Use Beam Search with Beam Search Size %d.'
                  % self.model_config.beam_search_size)
            return self.transformer_beam_search(
                encoder_outputs, encoder_attn_bias, encoder_embed_inputs_list,
                sentence_complex_input_placeholder, emb_simple, w, b,
                rule_id_input_placeholder, mem_contexts, mem_outputs, global_step)

        gt_target_list = sentence_simple_input_placeholder
        output = ModelOutput(
            contexts=cur_context if 'rule' in self.model_config.memory else None,
            encoder_outputs=encoder_outputs,
            decoder_outputs_list=(final_output_list
                                  if train_mode != 'dynamic_self-critical' else None),
            final_outputs_list=(final_output_list
                                if train_mode != 'dynamic_self-critical' else None),
            decoder_logit_list=(decoder_logit_list
                                if train_mode != 'dynamic_self-critical' else None),
            gt_target_list=gt_target_list,
            encoder_embed_inputs_list=tf.unstack(encoder_embed_inputs, axis=1),
            decoder_target_list=decoder_target_list,
            sample_logit_list=(sampled_logit_list
                               if train_mode == 'dynamic_self-critical' else None),
            sample_target_list=(sampled_target_list
                                if train_mode == 'dynamic_self-critical' else None))
        return output
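# --- Illustrative sketch (not from the original model) ----------------------
# Core of the dynamic_self-critical pattern above: a tf.while_loop that, at
# each step, samples one token id per batch element and appends it to a
# TensorArray, which is finally stacked into [batch, max_len]. logits_fn is a
# hypothetical stand-in for a real decoder call.
def _sampling_loop_sketch(logits_fn, max_len, batch_size):
    import tensorflow as tf
    ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True,
                        clear_after_read=False, element_shape=[batch_size])

    def _cond(step, ta):
        return tf.less(step, max_len)

    def _body(step, ta):
        logits = logits_fn(step)  # [batch, vocab]
        sampled = tf.cast(tf.squeeze(tf.multinomial(logits, 1), axis=1), tf.int32)
        return step + 1, ta.write(step, sampled)

    _, ta = tf.while_loop(_cond, _body, [tf.constant(0), ta],
                          back_prop=True, parallel_iterations=1)
    return tf.transpose(ta.stack(), perm=[1, 0])  # [batch, max_len]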