コード例 #1
0
    def body(self, features, mode):
        """Body of the model, aka Bert.

        Builds a ``BertModel`` from ``self.config.bert_config`` over the
        given features and collects its outputs into a dict.

        Arguments:
            features {dict} -- feature dict,
                keys: input_ids, input_mask, segment_ids
            mode {mode} -- estimator mode; BERT is built with
                ``is_training=True`` only when
                ``mode == tf.estimator.ModeKeys.TRAIN``.

        Returns:
            dict -- features extracted from bert.
                keys: 'seq', 'pooled', 'all', 'embed', 'embed_table'

        seq:
            tensor, [batch_size, seq_length, hidden_size]
        pooled:
            tensor, [batch_size, hidden_size]
        all:
            list of tensor, num_hidden_layers * [batch_size, seq_length, hidden_size]
        embed:
            tensor, [batch_size, seq_length, hidden_size]
        embed_table:
            the word-embedding table from ``get_embedding_table()``
            (presumably [vocab_size, hidden_size] -- TODO confirm)
        """

        config = self.config
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        model = BertModel(config=config.bert_config,
                          is_training=is_training,
                          input_ids=input_ids,
                          input_mask=input_mask,
                          token_type_ids=segment_ids,
                          use_one_hot_embeddings=config.use_one_hot_embeddings)

        # Values are evaluated left to right, so the getters run in the same
        # order and the dict keeps the same key order as before.
        feature_dict = {
            # final-layer sequence output
            'seq': model.get_sequence_output(),
            # pooled output
            'pooled': model.get_pooled_output(),
            # list with one output per encoder layer
            'all': model.get_all_encoder_layers(),
            # embedding-layer output, kept for residual connections
            'embed': model.get_embedding_output(),
            # word-embedding table
            'embed_table': model.get_embedding_table(),
        }

        return feature_dict
コード例 #2
0
class BertEncoder(object):
    """Thin wrapper around ``BertModel`` with an encoder-style API.

    ``encode()`` returns the final-layer sequence output together with one
    mean-reduced summary tensor per encoder layer.
    """

    def __init__(self,
                 config,
                 is_training,
                 input_ids,
                 input_mask=None,
                 token_type_ids=None):
        """Build the wrapped ``BertModel``.

        Arguments:
            config -- BERT config, forwarded to ``BertModel``.
            is_training -- forwarded to ``BertModel``.
            input_ids -- token-id tensor, forwarded to ``BertModel``.
            input_mask -- optional mask, forwarded as-is.
            token_type_ids -- optional segment ids, forwarded as-is.
        """
        self.model = BertModel(config=config,
                               is_training=is_training,
                               input_ids=input_ids,
                               input_mask=input_mask,
                               token_type_ids=token_type_ids)

        # Word-embedding table of the wrapped model, exposed for callers.
        self.embeddings_table = self.model.get_embedding_table()

    def encode(self):
        """Return ``(output, states)`` for the wrapped BERT model.

        Returns:
            output -- ``get_sequence_output()``, the final-layer sequence
                output (presumably [batch_size, seq_length, hidden_size]).
            states -- tuple with one entry per tensor returned by
                ``get_all_encoder_layers()``; each entry is that layer's
                output averaged over axis 1 (presumably the sequence axis).
        """
        output = self.model.get_sequence_output()
        # Built in one pass instead of repeated tuple concatenation.
        states = tuple(tf.reduce_mean(layer, axis=1)
                       for layer in self.model.get_all_encoder_layers())
        return output, states
コード例 #3
0
ファイル: model.py プロジェクト: xurenlu/chatbot
class ChatModel:
    """Chat model: BERT encoder feeding a GRU attention decoder (TF 1.x graph).

    ``__init__`` reads hyper-parameters and creates the input placeholders.
    ``create_model`` builds the encoder, a scheduled-sampling training
    decoder and a weight-sharing beam-search decoder. ``loss`` calls
    ``create_model`` and adds the training loss and an edit-distance metric.
    """

    def __init__(self, chatmodel_config):
        """Read hyper-parameters and create the input placeholders.

        Arguments:
            chatmodel_config -- object exposing ``max_x_len``, ``max_y_len``,
                ``max_decode_len``, ``vocab`` (indexed by token strings and
                must contain ``'<S>'`` and ``'<T>'``), ``config_file``
                (BERT JSON config path), ``ckpt_file`` (BERT checkpoint
                path or None), ``beam_width``, ``dropout_rate``,
                ``coverage_penalty_weight`` and ``length_penalty_weight``.
        """
        self.chatmodel_config = chatmodel_config
        self.max_x_len = chatmodel_config.max_x_len
        self.max_y_len = chatmodel_config.max_y_len
        self.decode_max_len = chatmodel_config.max_decode_len
        self.vocab = chatmodel_config.vocab
        self.config_file = chatmodel_config.config_file
        self.ckpt_file = chatmodel_config.ckpt_file
        self.beam_width = chatmodel_config.beam_width
        # NOTE(review): dropout_rate is stored but never read in this class.
        self.dropout_rate = chatmodel_config.dropout_rate
        self.coverage_penalty_weight = chatmodel_config.coverage_penalty_weight
        self.length_penalty_weight = chatmodel_config.length_penalty_weight
        # Encoder inputs, each [batch, max_x_len]: token ids, input mask and
        # segment ids. x_len holds each example's input length.
        self.x = tf.placeholder(tf.int32,
                                shape=[None, self.max_x_len],
                                name='x')
        self.x_mask = tf.placeholder(tf.int32,
                                     shape=[None, self.max_x_len],
                                     name='x_mask')
        self.x_seg = tf.placeholder(tf.int32,
                                    shape=[None, self.max_x_len],
                                    name='x_seg')
        self.x_len = tf.placeholder(tf.int32, shape=[None], name='x_len')
        # Decoder targets [batch, max_y_len] and each example's target length.
        self.y = tf.placeholder(tf.int32,
                                shape=[None, self.max_y_len],
                                name='y')
        self.y_len = tf.placeholder(tf.int32, shape=[None], name='y_len')

    def create_model(self):
        """Build the encoder, the training decoder and the beam-search decoder.

        Side effects: sets ``bert_config``, ``vocab_size``, ``hidden_size``,
        ``bert_model``, ``embeddings`` and ``end_token`` on ``self``. When
        ``ckpt_file`` is set, also sets ``assignment_map`` and
        ``initialized_variable_map``.

        Returns:
            (t_output, p_output) -- ``t_output`` is the training decoder's
            ``dynamic_decode`` output (used for ``rnn_output`` and
            ``sample_id``). ``p_output`` is the first beam (index 0 on the
            last axis) of the beam-search ``predicted_ids``, named
            ``'predictions'`` in the graph.
        """
        self.bert_config = BertConfig.from_json_file(self.config_file)
        self.vocab_size = self.bert_config.vocab_size
        self.hidden_size = self.bert_config.hidden_size
        # NOTE(review): is_training is hard-coded to True, and this single
        # encoder feeds both the training and the beam-search decoder.
        # Any training-only behaviour of BertModel (presumably dropout)
        # therefore also applies when producing predictions. Confirm this
        # is intended.
        self.bert_model = BertModel(config=self.bert_config,
                                    input_ids=self.x,
                                    input_mask=self.x_mask,
                                    token_type_ids=self.x_seg,
                                    is_training=True,
                                    use_one_hot_embeddings=False)
        if self.ckpt_file is not None:
            # Only the checkpoint-to-graph variable mapping is computed here.
            # Nothing in this class applies it, so presumably the caller
            # passes it to tf.train.init_from_checkpoint.
            tvars = tf.trainable_variables()
            self.assignment_map, self.initialized_variable_map = modeling.get_assignment_map_from_checkpoint(
                tvars, self.ckpt_file)
        X = self.bert_model.get_sequence_output()
        # BERT's word-embedding table is reused as the decoder's embedding.
        self.embeddings = self.bert_model.get_embedding_table()
        # Position 0 (presumably [CLS]) is dropped from the attention memory.
        # Its vector becomes the decoder's initial cell state.
        encoder_output = X[:, 1:, :]
        encoder_state = X[:, 0, :]
        batch_size = tf.shape(self.x)[0]
        start_token = tf.ones([batch_size], dtype=tf.int32) * self.vocab['<S>']
        # Decoder inputs are <S> followed by y, so step t reads y[:, t - 1].
        train_output = tf.concat([tf.expand_dims(start_token, 1), self.y], 1)
        output_emb = tf.nn.embedding_lookup(self.embeddings, train_output)
        output_len = self.y_len
        # Scheduled sampling: with probability 0.1 a step is fed the
        # decoder's own sampled id instead of the ground-truth input.
        train_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
            output_emb, output_len, self.embeddings, 0.1)
        # Presumably excludes [CLS] (sliced off above) and a trailing [SEP]
        # from the attention memory length. TODO: confirm against the input
        # pipeline.
        input_len = self.x_len - 2
        # One cell object is shared by both decoders below.
        cell = tf.contrib.rnn.GRUCell(num_units=self.hidden_size)

        def decode(scope):
            """Build the training and beam-search decoders under `scope`.

            Returns (t_final_output, p_final_output), the first elements
            of the two dynamic_decode results.
            """
            # Training decoder; its wrapper is built with reuse=False.
            with tf.variable_scope(scope):
                attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                    num_units=self.hidden_size,
                    memory=encoder_output,
                    memory_sequence_length=input_len)
                attention_cell = tf.contrib.seq2seq.AttentionWrapper(
                    cell=cell,
                    attention_mechanism=attention_mechanism,
                    attention_layer_size=self.hidden_size)
                out_cell = MyOutputProjectionWrapper(attention_cell,
                                                     self.vocab_size,
                                                     self.embeddings,
                                                     reuse=False)
                initial_state = out_cell.zero_state(dtype=tf.float32,
                                                    batch_size=batch_size)
                initial_state = initial_state.clone(cell_state=encoder_state)
                decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell=out_cell,
                    helper=train_helper,
                    initial_state=initial_state)
                t_final_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder=decoder,
                    output_time_major=False,
                    impute_finished=True,
                    maximum_iterations=self.decode_max_len)
            # Inference decoder: re-enters the same scope with reuse=True to
            # share the training decoder's variables. Encoder tensors are
            # tiled beam_width times for beam search.
            with tf.variable_scope(scope, reuse=True):
                tiled_encoder_output = tf.contrib.seq2seq.tile_batch(
                    encoder_output, multiplier=self.beam_width)
                tiled_encoder_state = tf.contrib.seq2seq.tile_batch(
                    encoder_state, multiplier=self.beam_width)
                tiled_input_len = tf.contrib.seq2seq.tile_batch(
                    input_len, multiplier=self.beam_width)
                attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                    num_units=self.hidden_size,
                    memory=tiled_encoder_output,
                    memory_sequence_length=tiled_input_len)
                attention_cell = tf.contrib.seq2seq.AttentionWrapper(
                    cell=cell,
                    attention_mechanism=attention_mechanism,
                    attention_layer_size=self.hidden_size)
                out_cell = MyOutputProjectionWrapper(attention_cell,
                                                     self.vocab_size,
                                                     self.embeddings,
                                                     reuse=True)
                initial_state = out_cell.zero_state(dtype=tf.float32,
                                                    batch_size=batch_size *
                                                    self.beam_width)
                initial_state = initial_state.clone(
                    cell_state=tiled_encoder_state)
                self.end_token = self.vocab['<T>']
                beamDecoder = tf.contrib.seq2seq.BeamSearchDecoder(
                    cell=out_cell,
                    embedding=self.embeddings,
                    start_tokens=start_token,
                    end_token=self.end_token,
                    initial_state=initial_state,
                    beam_width=self.beam_width,
                    coverage_penalty_weight=self.coverage_penalty_weight,
                    length_penalty_weight=self.length_penalty_weight)
                p_final_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder=beamDecoder,
                    output_time_major=False,
                    maximum_iterations=self.decode_max_len)
            return t_final_output, p_final_output

        t_output, p_output = decode('decode')

        # Keep only the first beam (index 0 on the last axis).
        p_output = tf.identity(p_output.predicted_ids[:, :, 0],
                               name='predictions')
        return t_output, p_output

    def loss(self):
        """Build the model and return its training loss and metrics.

        Returns:
            (loss, distance, p_output, sample_id), where:
            - loss is the masked sequence cross-entropy of the training
              decoder;
            - distance is the summed, unnormalized edit distance between
              the beam-search predictions and ``self.y``, ignoring ids 0
              and ``end_token``;
            - p_output is the beam-search predictions;
            - sample_id is the training decoder's ``sample_id``.
        """
        t_output, p_output = self.create_model()
        decode_len = tf.shape(t_output.sample_id)[-1]
        y_target = self.y[:, :decode_len]
        # Weight only real target positions (t < y_len). Steps at or beyond
        # y_len are padding, and dynamic_decode(impute_finished=True) emits
        # zero outputs there, so they must not count towards the loss.
        # (Taking tf.maximum(decode_len, y_len) here would make every length
        # >= decode_len and the mask all ones.)
        y_mask = tf.sequence_mask(self.y_len, decode_len, dtype=tf.float32)
        loss = tf.contrib.seq2seq.sequence_loss(t_output.rnn_output,
                                                y_target,
                                                weights=y_mask)
        p_output_sparse = self._convert_tensor_to_sparse(
            p_output, self.end_token)
        y_output_sparse = self._convert_tensor_to_sparse(
            self.y, self.end_token)
        distance = tf.reduce_sum(
            tf.edit_distance(p_output_sparse, y_output_sparse,
                             normalize=False))
        return loss, distance, p_output, t_output.sample_id

    def _convert_tensor_to_sparse(self, a, end_token):
        """Convert dense id tensor `a` to a SparseTensor of the same shape.

        Only entries that are neither 0 nor `end_token` are kept.
        """
        indices = tf.where(tf.not_equal(a, 0) & tf.not_equal(a, end_token))
        values = tf.gather_nd(a, indices)
        sparse_a = tf.SparseTensor(indices, values,
                                   tf.shape(a, out_type=tf.int64))
        return sparse_a