def build_attention_model(params, src_vocab, trg_vocab,
                          source_placeholders, target_placeholders,
                          beam_size=1, mode=MODE.TRAIN,
                          burn_in_step=100000, increment_step=10000,
                          teacher_rate=1.0, max_step=100):
    """ Build a model.

    :param params: dict.
        {encoder: {rnn_cell: {}, ...},
         decoder: {rnn_cell: {}, ...}}
        for example:
        {'encoder': {'rnn_cell': {'state_size': 512,
                                  'cell_name': 'BasicLSTMCell',
                                  'num_layers': 2,
                                  'input_keep_prob': 1.0,
                                  'output_keep_prob': 1.0},
                     'attention_key_size': attention_size},
         'decoder': {'rnn_cell': {'cell_name': 'BasicLSTMCell',
                                  'state_size': 512,
                                  'num_layers': 1,
                                  'input_keep_prob': 1.0,
                                  'output_keep_prob': 1.0},
                     'trg_vocab_size': trg_vocab_size}}
        This variant additionally expects params['char_encoder'] (the
        char-level encoder config) and params['attention_key_size'] with
        'char' and 'word' entries.
    :param src_vocab: Vocab of source symbols.
    :param trg_vocab: Vocab of target symbols.
    :param source_placeholders: dict of source placeholders with keys
        'src', 'src_len', 'src_sample_matrix' and 'src_word_len'.
    :param target_placeholders: dict of target placeholders with keys
        'trg' and 'trg_len'.
    :param beam_size: used in beam inference
    :param mode:
    :return:
    """
    if mode != MODE.TRAIN:
        params = sq.disable_dropout(params)

    tf.logging.info(json.dumps(params, indent=4))

    decoder_params = params['decoder']

    # unpack the placeholders
    source_ids = source_placeholders['src']
    source_seq_length = source_placeholders['src_len']
    source_sample_matrix = source_placeholders['src_sample_matrix']
    source_word_seq_length = source_placeholders['src_word_len']

    target_ids = target_placeholders['trg']
    target_seq_length = target_placeholders['trg_len']

    # Because the source encoder differs from the target feedback,
    # the source embedding table is constructed manually.
    source_char_embedding_table = sq.LookUpOp(src_vocab.vocab_size,
                                              src_vocab.embedding_dim,
                                              name='source')
    source_char_embedded = source_char_embedding_table(source_ids)

    # encode characters into word-level representations
    char_encoder = sq.StackRNNEncoder(params['char_encoder'],
                                      params['attention_key_size']['char'],
                                      name='char_rnn',
                                      mode=mode)
    # char_encoder_outputs: [T_c, B, F] (time-major)
    char_encoded_representation = char_encoder.encode(source_char_embedded,
                                                      source_seq_length)
    char_encoder_outputs = char_encoded_representation.outputs

    # Abandoned gather_nd-based word sampling, kept for reference:
    # dynamical_batch_size = tf.shape(char_encoder_outputs)[1]
    # space_indices = tf.where(tf.equal(tf.transpose(source_ids),
    #                                   src_vocab.space_id))
    # space_indices = tf.concat(tf.split(space_indices, 2, axis=1)[::-1],
    #                           axis=1)
    # space_indices = tf.transpose(
    #     tf.reshape(space_indices, [dynamical_batch_size, -1, 2]), [1, 0, 2])
    # # T_w * B * F
    # source_embedded = tf.gather_nd(char_encoder_outputs, space_indices)

    # Sample word-boundary outputs from the char encoder: transpose to
    # batch-major [B, T_c, F] for the batched matmul, then back to
    # time-major [T_w, B, F] because the word encoder must be time-major.
    char_encoder_outputs = tf.transpose(char_encoder_outputs, perm=(1, 0, 2))
    sampled_word_embedded = tf.matmul(source_sample_matrix,
                                      char_encoder_outputs)
    source_embedded = tf.transpose(sampled_word_embedded, perm=(1, 0, 2))

    char_attention_keys = char_encoded_representation.attention_keys
    char_attention_values = char_encoded_representation.attention_values
    char_attention_length = char_encoded_representation.attention_length

    encoder = sq.StackBidirectionalRNNEncoder(
        params['encoder'],
        params['attention_key_size']['word'],
        name='stack_rnn',
        mode=mode)
    encoded_representation = encoder.encode(source_embedded,
                                            source_word_seq_length)
    attention_keys = encoded_representation.attention_keys
    attention_values = encoded_representation.attention_values
    attention_length = encoded_representation.attention_length
    encoder_final_states_bw = encoded_representation.final_state[-1][-1].h

    # feedback
    if mode == MODE.RL:
        tf.logging.info('BUILDING RL TRAIN FEEDBACK......')
        dynamical_batch_size = tf.shape(attention_keys)[1]
        feedback = sq.RLTrainingFeedBack(target_ids, target_seq_length,
                                         trg_vocab, dynamical_batch_size,
                                         burn_in_step=burn_in_step,
                                         increment_step=increment_step,
                                         max_step=max_step)
    elif mode == MODE.TRAIN:
        tf.logging.info('BUILDING TRAIN FEEDBACK WITH {} TEACHER_RATE'
                        '......'.format(teacher_rate))
        feedback = sq.TrainingFeedBack(target_ids, target_seq_length,
                                       trg_vocab, teacher_rate,
                                       max_step=max_step)
    elif mode == MODE.EVAL:
        tf.logging.info('BUILDING EVAL FEEDBACK ......')
        feedback = sq.TrainingFeedBack(target_ids, target_seq_length,
                                       trg_vocab, 0., max_step=max_step)
    else:
        tf.logging.info('BUILDING INFER FEEDBACK WITH BEAM_SIZE {}'
                        '......'.format(beam_size))
        infer_key_size = attention_keys.get_shape().as_list()[-1]
        infer_value_size = attention_values.get_shape().as_list()[-1]
        infer_states_bw_shape = \
            encoder_final_states_bw.get_shape().as_list()[-1]
        infer_char_key_size = char_attention_keys.get_shape().as_list()[-1]
        infer_char_value_size = \
            char_attention_values.get_shape().as_list()[-1]
        encoder_final_states_bw = tf.reshape(
            tf.tile(encoder_final_states_bw, [1, beam_size]),
            [-1, infer_states_bw_shape])

        # expand beam
        if TIME_MAJOR:
            # batch size should be dynamical
            dynamical_batch_size = tf.shape(attention_keys)[1]
            final_key_shape = [
                -1, dynamical_batch_size * beam_size, infer_key_size
            ]
            final_value_shape = [
                -1, dynamical_batch_size * beam_size, infer_value_size
            ]
            attention_keys = tf.reshape(
                (tf.tile(attention_keys, [1, 1, beam_size])), final_key_shape)
            attention_values = tf.reshape(
                (tf.tile(attention_values, [1, 1, beam_size])),
                final_value_shape)

            final_char_key_shape = [
                -1, dynamical_batch_size * beam_size, infer_char_key_size
            ]
            final_char_value_shape = [
                -1, dynamical_batch_size * beam_size, infer_char_value_size
            ]
            char_attention_keys = tf.reshape(
                (tf.tile(char_attention_keys, [1, 1, beam_size])),
                final_char_key_shape)
            char_attention_values = tf.reshape(
                (tf.tile(char_attention_values, [1, 1, beam_size])),
                final_char_value_shape)
        else:
            dynamical_batch_size = tf.shape(attention_keys)[0]
            final_key_shape = [
                dynamical_batch_size * beam_size, -1, infer_key_size
            ]
            final_value_shape = [
                dynamical_batch_size * beam_size, -1, infer_value_size
            ]
            final_char_key_shape = [
                dynamical_batch_size * beam_size, -1, infer_char_key_size
            ]
            final_char_value_shape = [
                dynamical_batch_size * beam_size, -1, infer_char_value_size
            ]
            attention_keys = tf.reshape(
                (tf.tile(attention_keys, [1, beam_size, 1])), final_key_shape)
            attention_values = tf.reshape(
                (tf.tile(attention_values, [1, beam_size, 1])),
                final_value_shape)
            char_attention_keys = tf.reshape(
                (tf.tile(char_attention_keys, [1, beam_size, 1])),
                final_char_key_shape)
            char_attention_values = tf.reshape(
                (tf.tile(char_attention_values, [1, beam_size, 1])),
                final_char_value_shape)

        attention_length = tf.reshape(
            tf.transpose(tf.tile([attention_length], [beam_size, 1])), [-1])
        char_attention_length = tf.reshape(
            tf.transpose(tf.tile([char_attention_length], [beam_size, 1])),
            [-1])

        feedback = sq.BeamFeedBack(trg_vocab, beam_size, dynamical_batch_size,
                                   max_step=max_step)

    encoder_decoder_bridge = EncoderDecoderBridge(
        encoder_final_states_bw.get_shape().as_list()[-1],
        decoder_params['rnn_cell'])
    decoder_state_size = decoder_params['rnn_cell']['state_size']

    # attention over both the word-level and the char-level memories
    attention = sq.AvAttention(decoder_state_size,
                               attention_keys, attention_values,
                               attention_length,
                               char_attention_keys, char_attention_values,
                               char_attention_length)
    context_size = attention.context_size

    with tf.variable_scope('logits_func'):
        attention_mix = LinearOp(
            context_size + feedback.embedding_dim + decoder_state_size,
            decoder_state_size, name='attention_mix')
        attention_mix_middle = LinearOp(decoder_state_size,
                                        decoder_state_size // 2,
                                        name='attention_mix_middle')
        logits_trans = LinearOp(decoder_state_size // 2, feedback.vocab_size,
                                name='logits_trans')
        logits_func = lambda _softmax: logits_trans(
            tf.nn.relu(
                attention_mix_middle(tf.nn.relu(attention_mix(_softmax)))))

    # decoder
    decoder = sq.AttentionRNNDecoder(
        decoder_params, attention, feedback,
        logits_func=logits_func,
        init_state=encoder_decoder_bridge(encoder_final_states_bw),
        mode=mode)
    decoder_output, decoder_final_state = sq.dynamic_decode(decoder,
                                                            swap_memory=True,
                                                            scope='decoder')

    # not training: return the raw decoding results
    if mode == MODE.EVAL or mode == MODE.INFER:
        return decoder_output, decoder_final_state

    # bos is added inside the feedback,
    # so target_ids lines up with predict_ids
    if not TIME_MAJOR:
        ground_truth_ids = tf.transpose(target_ids, [1, 0])
    else:
        ground_truth_ids = target_ids

    # construct the loss
    if mode == MODE.RL:
        # fetch the variable that holds the global_step
        global_step_tensor = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                               scope='global_step')[0]
        rl_time_steps = tf.floordiv(
            tf.maximum(global_step_tensor - burn_in_step, 0), increment_step)
        start_rl_step = target_seq_length - rl_time_steps

        baseline_states = tf.stop_gradient(decoder_output.baseline_states)
        predict_ids = tf.stop_gradient(decoder_output.predicted_ids)

        # TODO: bug in tensorflow
        ground_or_predict_ids = tf.cond(tf.greater(rl_time_steps, 0),
                                        lambda: predict_ids,
                                        lambda: ground_truth_ids)

        reward, sequence_length = tf.py_func(
            func=_py_func,
            inp=[ground_or_predict_ids, ground_truth_ids, trg_vocab.eos_id],
            Tout=[tf.float32, tf.int32],
            name='reward')
        sequence_length.set_shape((None,))

        total_loss_avg, entropy_loss_avg, reward_loss_rmse, reward_predicted \
            = rl_sequence_loss(
                logits=decoder_output.logits,
                predict_ids=predict_ids,
                sequence_length=sequence_length,
                baseline_states=baseline_states,
                start_rl_step=start_rl_step,
                reward=reward)
        return decoder_output, total_loss_avg, entropy_loss_avg, \
            reward_loss_rmse, reward_predicted
    else:
        total_loss_avg = cross_entropy_sequence_loss(
            logits=decoder_output.logits,
            targets=ground_truth_ids,
            sequence_length=target_seq_length)
        return decoder_output, total_loss_avg, total_loss_avg, \
            tf.to_float(0.), tf.to_float(0.)
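# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original graph): one way the
# 'src_sample_matrix' placeholder consumed above could be built on the feed
# side. The helper name and the numpy construction are assumptions; the graph
# only requires that matmul(source_sample_matrix, char_encoder_outputs) pick,
# for every word, the char-encoder output at one boundary position per word
# (the commented-out gather_nd variant above used the space positions).
# One matrix would be built per example and stacked into a [B, T_w, T_c] feed.
def build_source_sample_matrix_example(char_ids, space_id, num_words):
    """Hypothetical helper: return a [num_words, num_chars] 0/1 matrix whose
    row w has a single 1 at the position of the space that closes word w."""
    import numpy as np
    char_ids = np.asarray(char_ids)
    num_chars = char_ids.shape[0]
    sample_matrix = np.zeros((num_words, num_chars), dtype=np.float32)
    boundary_positions = np.where(char_ids == space_id)[0]
    for word_idx, char_pos in enumerate(boundary_positions[:num_words]):
        sample_matrix[word_idx, char_pos] = 1.0
    return sample_matrix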
def build_attention_model(params, src_vocab, trg_vocab,
                          source_ids, source_seq_length,
                          target_ids, target_seq_length,
                          beam_size=1, mode=MODE.TRAIN,
                          burn_in_step=100000, increment_step=10000,
                          teacher_rate=1.0, max_step=100):
    """ Build a model.

    :param params: dict.
        {encoder: {rnn_cell: {}, ...},
         decoder: {rnn_cell: {}, ...}}
        for example:
        {'encoder': {'rnn_cell': {'state_size': 512,
                                  'cell_name': 'BasicLSTMCell',
                                  'num_layers': 2,
                                  'input_keep_prob': 1.0,
                                  'output_keep_prob': 1.0},
                     'attention_key_size': attention_size},
         'decoder': {'rnn_cell': {'cell_name': 'BasicLSTMCell',
                                  'state_size': 512,
                                  'num_layers': 1,
                                  'input_keep_prob': 1.0,
                                  'output_keep_prob': 1.0},
                     'trg_vocab_size': trg_vocab_size}}
    :param src_vocab: Vocab of source symbols.
    :param trg_vocab: Vocab of target symbols.
    :param source_ids: placeholder
    :param source_seq_length: placeholder
    :param target_ids: placeholder
    :param target_seq_length: placeholder
    :param beam_size: used in beam inference
    :param mode:
    :return:
    """
    if mode != MODE.TRAIN:
        params = sq.disable_dropout(params)

    tf.logging.info(json.dumps(params, indent=4))

    # parameters
    encoder_params = params['encoder']
    decoder_params = params['decoder']

    # Because the source encoder differs from the target feedback,
    # the source embedding table is constructed manually.
    source_embedding_table = sq.LookUpOp(src_vocab.vocab_size,
                                         src_vocab.embedding_dim,
                                         name='source')
    source_embedded = source_embedding_table(source_ids)

    encoder = sq.StackBidirectionalRNNEncoder(encoder_params,
                                              name='stack_rnn',
                                              mode=mode)
    encoded_representation = encoder.encode(source_embedded,
                                            source_seq_length)
    attention_keys = encoded_representation.attention_keys
    attention_values = encoded_representation.attention_values
    attention_length = encoded_representation.attention_length

    # feedback
    if mode == MODE.RL:
        tf.logging.info('BUILDING RL TRAIN FEEDBACK......')
        dynamical_batch_size = tf.shape(attention_keys)[1]
        feedback = sq.RLTrainingFeedBack(target_ids, target_seq_length,
                                         trg_vocab, dynamical_batch_size,
                                         burn_in_step=burn_in_step,
                                         increment_step=increment_step,
                                         max_step=max_step)
    elif mode == MODE.TRAIN:
        tf.logging.info('BUILDING TRAIN FEEDBACK WITH {} TEACHER_RATE'
                        '......'.format(teacher_rate))
        feedback = sq.TrainingFeedBack(target_ids, target_seq_length,
                                       trg_vocab, teacher_rate,
                                       max_step=max_step)
    elif mode == MODE.EVAL:
        tf.logging.info('BUILDING EVAL FEEDBACK ......')
        feedback = sq.TrainingFeedBack(target_ids, target_seq_length,
                                       trg_vocab, 0., max_step=max_step)
    else:
        tf.logging.info('BUILDING INFER FEEDBACK WITH BEAM_SIZE {}'
                        '......'.format(beam_size))
        infer_key_size = attention_keys.get_shape().as_list()[-1]
        infer_value_size = attention_values.get_shape().as_list()[-1]

        # expand beam
        if TIME_MAJOR:
            # batch size should be dynamical
            dynamical_batch_size = tf.shape(attention_keys)[1]
            final_key_shape = [
                -1, dynamical_batch_size * beam_size, infer_key_size
            ]
            final_value_shape = [
                -1, dynamical_batch_size * beam_size, infer_value_size
            ]
            attention_keys = tf.reshape(
                (tf.tile(attention_keys, [1, 1, beam_size])), final_key_shape)
            attention_values = tf.reshape(
                (tf.tile(attention_values, [1, 1, beam_size])),
                final_value_shape)
        else:
            dynamical_batch_size = tf.shape(attention_keys)[0]
            final_key_shape = [
                dynamical_batch_size * beam_size, -1, infer_key_size
            ]
            final_value_shape = [
                dynamical_batch_size * beam_size, -1, infer_value_size
            ]
            attention_keys = tf.reshape(
                (tf.tile(attention_keys, [1, beam_size, 1])), final_key_shape)
            attention_values = tf.reshape(
                (tf.tile(attention_values, [1, beam_size, 1])),
                final_value_shape)

        attention_length = tf.reshape(
            tf.transpose(tf.tile([attention_length], [beam_size, 1])), [-1])

        feedback = sq.BeamFeedBack(trg_vocab, beam_size, dynamical_batch_size,
                                   max_step=max_step)

    # attention
    attention = sq.Attention(decoder_params['rnn_cell']['state_size'],
                             attention_keys, attention_values,
                             attention_length)

    # decoder
    decoder = sq.AttentionRNNDecoder(decoder_params, attention, feedback,
                                     mode=mode)
    decoder_output, decoder_final_state = sq.dynamic_decode(decoder,
                                                            swap_memory=True,
                                                            scope='decoder')

    # not training: return the raw decoding results
    if mode == MODE.EVAL or mode == MODE.INFER:
        return decoder_output, decoder_final_state

    # bos is added inside the feedback,
    # so target_ids lines up with predict_ids
    if not TIME_MAJOR:
        ground_truth_ids = tf.transpose(target_ids, [1, 0])
    else:
        ground_truth_ids = target_ids

    # construct the loss
    if mode == MODE.RL:
        # fetch the variable that holds the global_step
        global_step_tensor = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                               scope='global_step')[0]
        rl_time_steps = tf.floordiv(
            tf.maximum(global_step_tensor - burn_in_step, 0), increment_step)
        start_rl_step = target_seq_length - rl_time_steps

        baseline_states = tf.stop_gradient(decoder_output.baseline_states)
        predict_ids = tf.stop_gradient(decoder_output.predicted_ids)

        # TODO: bug in tensorflow
        ground_or_predict_ids = tf.cond(tf.greater(rl_time_steps, 0),
                                        lambda: predict_ids,
                                        lambda: ground_truth_ids)

        reward, sequence_length = tf.py_func(
            func=_py_func,
            inp=[ground_or_predict_ids, ground_truth_ids, trg_vocab.eos_id],
            Tout=[tf.float32, tf.int32],
            name='reward')
        sequence_length.set_shape((None,))

        total_loss_avg, entropy_loss_avg, reward_loss_rmse, reward_predicted \
            = rl_sequence_loss(
                logits=decoder_output.logits,
                predict_ids=predict_ids,
                sequence_length=sequence_length,
                baseline_states=baseline_states,
                start_rl_step=start_rl_step,
                reward=reward)
        return decoder_output, total_loss_avg, entropy_loss_avg, \
            reward_loss_rmse, reward_predicted
    else:
        total_loss_avg = cross_entropy_sequence_loss(
            logits=decoder_output.logits,
            targets=ground_truth_ids,
            sequence_length=target_seq_length)
        return decoder_output, total_loss_avg, total_loss_avg, \
            tf.to_float(0.), tf.to_float(0.)
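# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original graph): a tiny check of the
# tile + reshape beam expansion used in the INFER branches above. In the
# batch-major case, tiling along axis 1 and reshaping to
# [batch * beam_size, -1, depth] repeats each batch entry beam_size times in
# a row, so all hypotheses of one source sentence stay adjacent. The values
# and shapes are made up for the demonstration; it relies on the module's
# existing tensorflow import.
def _demo_beam_expand(beam_size=3):
    with tf.Graph().as_default(), tf.Session() as sess:
        # two source sentences, length 1, depth 1, encoded as 1.0 and 2.0
        keys = tf.constant([[[1.0]], [[2.0]]])  # shape [2, 1, 1]
        expanded = tf.reshape(tf.tile(keys, [1, beam_size, 1]),
                              [2 * beam_size, -1, 1])
        # prints [1. 1. 1. 2. 2. 2.]: entry 1.0 repeated beam_size times,
        # then entry 2.0 repeated beam_size times
        print(sess.run(expanded)[:, 0, 0])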
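# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original module): wiring placeholders
# to the word-level build_attention_model above in TRAIN mode. The placeholder
# shapes, the literal 512/'BasicLSTMCell' values (mirroring the docstring
# example) and the helper name are assumptions, not code from this project.
def _demo_build_train_graph(src_vocab, trg_vocab):
    params = {'encoder': {'rnn_cell': {'state_size': 512,
                                       'cell_name': 'BasicLSTMCell',
                                       'num_layers': 2,
                                       'input_keep_prob': 1.0,
                                       'output_keep_prob': 1.0},
                          'attention_key_size': 512},
              'decoder': {'rnn_cell': {'cell_name': 'BasicLSTMCell',
                                       'state_size': 512,
                                       'num_layers': 1,
                                       'input_keep_prob': 1.0,
                                       'output_keep_prob': 1.0},
                          'trg_vocab_size': trg_vocab.vocab_size}}
    source_ids = tf.placeholder(tf.int32, shape=[None, None], name='src')
    source_seq_length = tf.placeholder(tf.int32, shape=[None], name='src_len')
    target_ids = tf.placeholder(tf.int32, shape=[None, None], name='trg')
    target_seq_length = tf.placeholder(tf.int32, shape=[None], name='trg_len')
    return build_attention_model(params, src_vocab, trg_vocab,
                                 source_ids, source_seq_length,
                                 target_ids, target_seq_length,
                                 mode=MODE.TRAIN)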