def body(self, features, mode):
    """Body of the model, aka Bert.

    Arguments:
        features {dict} -- feature dict, keys: input_ids, input_mask, segment_ids
        mode {mode} -- estimator mode; the encoder is built with
            is_training=True only when mode == tf.estimator.ModeKeys.TRAIN

    Returns:
        dict -- features extracted from bert. keys: 'seq', 'pooled', 'all',
        'embed', 'embed_table'
            seq: tensor, [batch_size, seq_length, hidden_size]
            pooled: tensor, [batch_size, hidden_size]
            all: list of tensor, num_hidden_layers * [batch_size, seq_length, hidden_size]
            embed: tensor, [batch_size, seq_length, hidden_size]
            embed_table: result of model.get_embedding_table()
                (presumably [vocab_size, hidden_size] -- confirm)
    """
    config = self.config
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    model = BertModel(config=config.bert_config,
                      is_training=is_training,
                      input_ids=input_ids,
                      input_mask=input_mask,
                      token_type_ids=segment_ids,
                      use_one_hot_embeddings=config.use_one_hot_embeddings)

    # Build the result directly; accessor call order and key order match the
    # original loop over ['seq', 'pooled', 'all', 'embed', 'embed_table'].
    return {
        'seq': model.get_sequence_output(),
        'pooled': model.get_pooled_output(),
        'all': model.get_all_encoder_layers(),
        # Embedding-layer output, kept for a residual connection downstream.
        'embed': model.get_embedding_output(),
        'embed_table': model.get_embedding_table(),
    }
class BertEncoder(object):
    """Wraps a BertModel and exposes an encoder-style `encode()` method."""

    def __init__(self, config, is_training, input_ids, input_mask=None,
                 token_type_ids=None):
        self.model = BertModel(config=config,
                               is_training=is_training,
                               input_ids=input_ids,
                               input_mask=input_mask,
                               token_type_ids=token_type_ids)
        self.embeddings_table = self.model.get_embedding_table()

    def encode(self):
        """Return the sequence output and per-layer pooled states.

        Returns:
            (output, states):
                output -- model.get_sequence_output(), i.e. the BERT
                    sequence_output ([batch_size, seq_length, hidden_size]).
                states -- tuple with one entry per encoder layer, each the
                    mean of that layer's output over axis 1 (the sequence
                    axis, assuming [batch, seq, hidden] layer outputs).
        """
        output = self.model.get_sequence_output()
        states = tuple(tf.reduce_mean(layer, axis=1)
                       for layer in self.model.get_all_encoder_layers())
        return output, states
class ChatModel:
    """BERT encoder + attention-GRU decoder sequence model (TF1 graph mode).

    The constructor only reads config values and declares placeholders.
    `loss()` calls `create_model()` to build the encoder, a training decoder
    and a beam-search inference decoder sharing the same variables.
    """

    def __init__(self, chatmodel_config):
        self.chatmodel_config = chatmodel_config
        self.max_x_len = chatmodel_config.max_x_len
        self.max_y_len = chatmodel_config.max_y_len
        self.decode_max_len = chatmodel_config.max_decode_len
        self.vocab = chatmodel_config.vocab
        self.config_file = chatmodel_config.config_file
        self.ckpt_file = chatmodel_config.ckpt_file
        self.beam_width = chatmodel_config.beam_width
        # NOTE(review): stored but not used anywhere in this class.
        self.dropout_rate = chatmodel_config.dropout_rate
        self.coverage_penalty_weight = chatmodel_config.coverage_penalty_weight
        self.length_penalty_weight = chatmodel_config.length_penalty_weight

        self.x = tf.placeholder(tf.int32, shape=[None, self.max_x_len], name='x')
        self.x_mask = tf.placeholder(tf.int32, shape=[None, self.max_x_len], name='x_mask')
        self.x_seg = tf.placeholder(tf.int32, shape=[None, self.max_x_len], name='x_seg')
        self.x_len = tf.placeholder(tf.int32, shape=[None], name='x_len')
        self.y = tf.placeholder(tf.int32, shape=[None, self.max_y_len], name='y')
        self.y_len = tf.placeholder(tf.int32, shape=[None], name='y_len')

    def create_model(self):
        """Build encoder, training decoder and beam-search decoder.

        Side effects: sets bert_config, vocab_size, hidden_size, bert_model,
        embeddings, end_token and (when ckpt_file is set) assignment_map /
        initialized_variable_map on self.

        Returns:
            (t_output, p_output):
                t_output -- final outputs of the training BasicDecoder from
                    dynamic_decode (has .rnn_output and .sample_id).
                p_output -- tensor named 'predictions': the first beam's ids,
                    predicted_ids[:, :, 0].
        """
        self.bert_config = BertConfig.from_json_file(self.config_file)
        self.vocab_size = self.bert_config.vocab_size
        self.hidden_size = self.bert_config.hidden_size
        # NOTE(review): is_training is hard-coded to True, so the encoder is
        # built in training mode for inference as well -- confirm intended.
        self.bert_model = BertModel(config=self.bert_config,
                                    input_ids=self.x,
                                    input_mask=self.x_mask,
                                    token_type_ids=self.x_seg,
                                    is_training=True,
                                    use_one_hot_embeddings=False)
        if self.ckpt_file is not None:
            tvars = tf.trainable_variables()
            # Only the assignment map is computed here; this class does not
            # call tf.train.init_from_checkpoint with it.
            self.assignment_map, self.initialized_variable_map = modeling.get_assignment_map_from_checkpoint(
                tvars, self.ckpt_file)

        X = self.bert_model.get_sequence_output()
        self.embeddings = self.bert_model.get_embedding_table()
        # Position 0 becomes the decoder's initial state; positions 1.. are the
        # attention memory.
        encoder_output = X[:, 1:, :]
        encoder_state = X[:, 0, :]

        batch_size = tf.shape(self.x)[0]
        start_token = tf.ones([batch_size], dtype=tf.int32) * self.vocab['<S>']
        # Decoder inputs for training: <S> followed by the target ids.
        train_output = tf.concat([tf.expand_dims(start_token, 1), self.y], 1)
        output_emb = tf.nn.embedding_lookup(self.embeddings, train_output)
        output_len = self.y_len
        # Scheduled sampling with sampling probability 0.1.
        train_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
            output_emb, output_len, self.embeddings, 0.1)
        # Memory drops position 0 (one shorter than x); the extra -1 presumably
        # excludes a trailing separator token -- confirm against the data.
        input_len = self.x_len - 2
        cell = tf.contrib.rnn.GRUCell(num_units=self.hidden_size)

        def decode(scope):
            # Training decoder.
            with tf.variable_scope(scope):
                attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                    num_units=self.hidden_size,
                    memory=encoder_output,
                    memory_sequence_length=input_len)
                attention_cell = tf.contrib.seq2seq.AttentionWrapper(
                    cell=cell,
                    attention_mechanism=attention_mechanism,
                    attention_layer_size=self.hidden_size)
                out_cell = MyOutputProjectionWrapper(
                    attention_cell, self.vocab_size, self.embeddings, reuse=False)
                initial_state = out_cell.zero_state(
                    dtype=tf.float32, batch_size=batch_size)
                initial_state = initial_state.clone(cell_state=encoder_state)
                decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell=out_cell,
                    helper=train_helper,
                    initial_state=initial_state)
                t_final_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder=decoder,
                    output_time_major=False,
                    impute_finished=True,
                    maximum_iterations=self.decode_max_len)

            # Inference decoder: reuses the training variables and runs beam
            # search over beam_width-tiled encoder outputs/state/lengths.
            with tf.variable_scope(scope, reuse=True):
                tiled_encoder_output = tf.contrib.seq2seq.tile_batch(
                    encoder_output, multiplier=self.beam_width)
                tiled_encoder_state = tf.contrib.seq2seq.tile_batch(
                    encoder_state, multiplier=self.beam_width)
                tiled_input_len = tf.contrib.seq2seq.tile_batch(
                    input_len, multiplier=self.beam_width)
                attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                    num_units=self.hidden_size,
                    memory=tiled_encoder_output,
                    memory_sequence_length=tiled_input_len)
                attention_cell = tf.contrib.seq2seq.AttentionWrapper(
                    cell=cell,
                    attention_mechanism=attention_mechanism,
                    attention_layer_size=self.hidden_size)
                out_cell = MyOutputProjectionWrapper(
                    attention_cell, self.vocab_size, self.embeddings, reuse=True)
                initial_state = out_cell.zero_state(
                    dtype=tf.float32, batch_size=batch_size * self.beam_width)
                initial_state = initial_state.clone(
                    cell_state=tiled_encoder_state)
                self.end_token = self.vocab['<T>']
                beamDecoder = tf.contrib.seq2seq.BeamSearchDecoder(
                    cell=out_cell,
                    embedding=self.embeddings,
                    start_tokens=start_token,
                    end_token=self.end_token,
                    initial_state=initial_state,
                    beam_width=self.beam_width,
                    coverage_penalty_weight=self.coverage_penalty_weight,
                    length_penalty_weight=self.length_penalty_weight)
                p_final_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder=beamDecoder,
                    output_time_major=False,
                    maximum_iterations=self.decode_max_len)
            return t_final_output, p_final_output

        t_output, p_output = decode('decode')
        p_output = tf.identity(p_output.predicted_ids[:, :, 0], name='predictions')
        return t_output, p_output

    def loss(self):
        """Build the model and the training/evaluation tensors.

        Returns:
            (loss, distance, p_output, sample_id):
                loss -- sequence cross-entropy of the training decoder,
                    weighted only on target positions < y_len.
                distance -- sum over the batch of the unnormalized edit
                    distance between the first beam and self.y, ignoring id 0
                    and the end token.
                p_output -- first-beam predicted ids ('predictions').
                sample_id -- t_output.sample_id from the training decoder.
        """
        t_output, p_output = self.create_model()
        decode_len = tf.shape(t_output.sample_id)[-1]
        y_target = self.y[:, :decode_len]
        # Mask each row to its true target length within the decoded window.
        # (max(decode_len, y_len) is always >= decode_len, which yields an
        # all-ones mask and weights padding positions in the loss.)
        y_mask = tf.sequence_mask(self.y_len, decode_len, dtype=tf.float32)
        loss = tf.contrib.seq2seq.sequence_loss(
            t_output.rnn_output, y_target, weights=y_mask)

        p_output_sparse = self._convert_tensor_to_sparse(
            p_output, self.end_token)
        y_output_sparse = self._convert_tensor_to_sparse(
            self.y, self.end_token)
        distance = tf.reduce_sum(
            tf.edit_distance(p_output_sparse, y_output_sparse, normalize=False))
        return loss, distance, p_output, t_output.sample_id

    def _convert_tensor_to_sparse(self, a, end_token):
        """Convert a dense id tensor to a SparseTensor, dropping 0 and end_token.

        The dense shape of the result equals tf.shape(a) (as int64).
        """
        indices = tf.where(tf.not_equal(a, 0) & tf.not_equal(a, end_token))
        values = tf.gather_nd(a, indices)
        sparse_a = tf.SparseTensor(indices, values, tf.shape(a, out_type=tf.int64))
        return sparse_a