Example 1
class Generator(Singleton):

    def _build_keyword_encoder(self):
        """ Encode keyword into a vector."""
        self.keyword = tf.placeholder(
                shape = [_BATCH_SIZE, None, CHAR_VEC_DIM],
                dtype = tf.float32, 
                name = "keyword")
        self.keyword_length = tf.placeholder(
                shape = [_BATCH_SIZE],
                dtype = tf.int32,
                name = "keyword_length")
        _, bi_states = tf.nn.bidirectional_dynamic_rnn(
                cell_fw = tf.contrib.rnn.GRUCell(_NUM_UNITS // 2),
                cell_bw = tf.contrib.rnn.GRUCell(_NUM_UNITS // 2),
                inputs = self.keyword,
                sequence_length = self.keyword_length,
                dtype = tf.float32, 
                time_major = False,
                scope = "keyword_encoder")
        self.keyword_state = tf.concat(bi_states, axis = 1)
        tf.TensorShape([_BATCH_SIZE, _NUM_UNITS]).\
                assert_same_rank(self.keyword_state.shape)
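        # bi_states is a (forward, backward) pair of final GRU states, each of
        # shape [_BATCH_SIZE, _NUM_UNITS // 2]; concatenating them along axis 1
        # gives the [_BATCH_SIZE, _NUM_UNITS] keyword vector asserted above.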

    def _build_context_encoder(self):
        """ Encode context into a list of vectors. """
        self.context = tf.placeholder(
                shape = [_BATCH_SIZE, None, CHAR_VEC_DIM],
                dtype = tf.float32, 
                name = "context")
        self.context_length = tf.placeholder(
                shape = [_BATCH_SIZE],
                dtype = tf.int32,
                name = "context_length")
        bi_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                cell_fw = tf.contrib.rnn.GRUCell(_NUM_UNITS // 2),
                cell_bw = tf.contrib.rnn.GRUCell(_NUM_UNITS // 2),
                inputs = self.context,
                sequence_length = self.context_length,
                dtype = tf.float32, 
                time_major = False,
                scope = "context_encoder")
        self.context_outputs = tf.concat(bi_outputs, axis = 2)
        tf.TensorShape([_BATCH_SIZE, None, _NUM_UNITS]).\
                assert_same_rank(self.context_outputs.shape)

    def _build_decoder(self):
        """ Decode keyword and context into a sequence of vectors. """
        attention = tf.contrib.seq2seq.BahdanauAttention(
                num_units = _NUM_UNITS, 
                memory = self.context_outputs,
                memory_sequence_length = self.context_length)
        decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
                cell = tf.contrib.rnn.GRUCell(_NUM_UNITS),
                attention_mechanism = attention)
        self.decoder_init_state = decoder_cell.zero_state(
                batch_size = _BATCH_SIZE, dtype = tf.float32).\
                        clone(cell_state = self.keyword_state)
        self.decoder_inputs = tf.placeholder(
                shape = [_BATCH_SIZE, None, CHAR_VEC_DIM],
                dtype = tf.float32, 
                name = "decoder_inputs")
        self.decoder_input_length = tf.placeholder(
                shape = [_BATCH_SIZE],
                dtype = tf.int32,
                name = "decoder_input_length")
        self.decoder_outputs, self.decoder_final_state = tf.nn.dynamic_rnn(
                cell = decoder_cell,
                inputs = self.decoder_inputs,
                sequence_length = self.decoder_input_length,
                initial_state = self.decoder_init_state,
                dtype = tf.float32, 
                time_major = False,
                scope = "training_decoder")
        tf.TensorShape([_BATCH_SIZE, None, _NUM_UNITS]).\
                assert_same_rank(self.decoder_outputs.shape)
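        # Wiring summary: the keyword encoder's final state initializes the
        # AttentionWrapper's cell state, while the context encoder's per-step
        # outputs act as the attention memory, so every decoding step attends
        # over the context and is conditioned on the keyword.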

    def _build_projector(self):
        """ Project decoder_outputs into character space. """
        softmax_w = tf.Variable(
                tf.random_normal(shape = [_NUM_UNITS, len(self.char_dict)],
                    mean = 0.0, stddev = 0.08), 
                trainable = True)
        softmax_b = tf.Variable(
                tf.random_normal(shape = [len(self.char_dict)],
                    mean = 0.0, stddev = 0.08),
                trainable = True)
        reshaped_outputs = self._reshape_decoder_outputs()
        self.logits = tf.nn.bias_add(
                tf.matmul(reshaped_outputs, softmax_w),
                bias = softmax_b)
        self.probs = tf.nn.softmax(self.logits)
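        # logits and probs have shape [sum(decoder_input_length), len(self.char_dict)],
        # matching the flattened targets fed in _build_optimizer.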

    def _reshape_decoder_outputs(self):
        """ Reshape decoder_outputs into shape [?, _NUM_UNITS]. """
        def concat_output_slices(idx, val):
            output_slice = tf.slice(
                    input_ = self.decoder_outputs,
                    begin = [idx, 0, 0],
                    size = [1, self.decoder_input_length[idx],  _NUM_UNITS])
            return tf.add(idx, 1),\
                    tf.concat([val, tf.squeeze(output_slice, axis = 0)], 
                            axis = 0)
        tf_i = tf.constant(0)
        tf_v = tf.zeros(shape = [0, _NUM_UNITS], dtype = tf.float32)
        _, reshaped_outputs = tf.while_loop(
                cond = lambda i, v: i < _BATCH_SIZE,
                body = concat_output_slices,
                loop_vars = [tf_i, tf_v],
                shape_invariants = [tf.TensorShape([]),
                    tf.TensorShape([None, _NUM_UNITS])])
        tf.TensorShape([None, _NUM_UNITS]).\
                assert_same_rank(reshaped_outputs.shape)
        return reshaped_outputs
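        # The while_loop keeps only the first decoder_input_length[i] timesteps of
        # each batch row, dropping padded positions before the loss is computed.
        # A hedged sketch of an equivalent masked version (assuming TF 1.x APIs):
        #     mask = tf.sequence_mask(self.decoder_input_length)        # [batch, time] bool
        #     reshaped = tf.boolean_mask(self.decoder_outputs, mask)    # [sum(lengths), _NUM_UNITS]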

    def _build_optimizer(self):
        """ Define cross-entropy loss and minimize it. """
        self.targets = tf.placeholder(
                shape = [None],
                dtype = tf.int32, 
                name = "targets")
        labels = tf.one_hot(self.targets, depth = len(self.char_dict))
        cross_entropy = tf.losses.softmax_cross_entropy(
                onehot_labels = labels,
                logits = self.logits)
        self.loss = tf.reduce_mean(cross_entropy)

        self.learning_rate = tf.clip_by_value(
                tf.multiply(1.6e-5, tf.pow(2.1, self.loss)),
                clip_value_min = 0.0002,
                clip_value_max = 0.02)
        self.opt_step = tf.train.AdamOptimizer(
                learning_rate = self.learning_rate).\
                        minimize(loss = self.loss)
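        # The learning rate is coupled to the current loss:
        #     lr = clip(1.6e-5 * 2.1 ** loss, 0.0002, 0.02)
        # e.g. loss = 6.0 gives roughly 1.6e-5 * 85.8 ~= 1.4e-3; a small loss is
        # floored at 0.0002 and a very large loss is capped at 0.02.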

    def _build_graph(self):
        self._build_keyword_encoder()
        self._build_context_encoder()
        self._build_decoder()
        self._build_projector()
        self._build_optimizer()

    def __init__(self):
        self.char_dict = CharDict()
        self.char2vec = Char2Vec()
        self._build_graph()
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        self.saver = tf.train.Saver(tf.global_variables())
        self.trained = False
        
    def _initialize_session(self, session):
        checkpoint = tf.train.get_checkpoint_state(save_dir)
        if not checkpoint or not checkpoint.model_checkpoint_path:
            init_op = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
            session.run(init_op)
        else:
            self.saver.restore(session, checkpoint.model_checkpoint_path)
            self.trained = True

    def _compute_prob_list(self, char, keyword_data, keyword_length, context_data,
            context_length, current_context, state, session, pron_dict):
        decoder_input, decoder_input_length = \
            self._fill_np_matrix([char])
        encoder_feed_dict = {
            self.keyword : keyword_data,
            self.keyword_length : keyword_length,
            self.context : context_data,
            self.context_length : context_length,
            self.decoder_inputs : decoder_input,
            self.decoder_input_length : decoder_input_length
            }
        # On the first step char is the start-of-sentence marker and the decoder
        # starts from keyword_state; afterwards feed back the previous step's state.
        if char != start_of_sentence():
            encoder_feed_dict[self.decoder_init_state] = state
        probs, state = session.run(
            [self.probs, self.decoder_final_state], 
            feed_dict = encoder_feed_dict)
        prob_list = self._gen_prob_list(probs, current_context, pron_dict)
        return prob_list, state

    def _return_n_most_likely(self, prob_list, number):
        """ Return the number-th most likely character, its score
            (negative log probability) and its index. prob_list is
            modified in place: every inspected maximum is zeroed out. """
        max_prob = 0
        used_index = 0
        char = ''
        score = 1
        while number > 0:
            for j, prob in enumerate(prob_list):
                if max_prob < prob:
                    char = self.char_dict.int2char(j)
                    max_prob = prob
                    score = -math.log(max_prob)
                    used_index = j
            prob_list[used_index] = 0
            max_prob = 0
            number -= 1
        return char, score, used_index

    def generate(self, keywords):
        assert NUM_OF_SENTENCES + 1 == len(keywords)  # the extra keyword is consumed as the random hint below
        pron_dict = PronDict()
        context = start_of_sentence()
        with tf.Session() as session:
            self._initialize_session(session)
            if not self.trained:
                print("Please train the model first! (./train.py -g)")
                sys.exit(1)
            # iterate over the keywords, i.e. over all four sentences

            # prepend a random hint to the first sentence so that the same keywords
            # do not always produce the same poem
            hint = keywords.pop(randrange(len(keywords)))

            first_line = True
            for keyword in keywords:
                if first_line:
                    context += hint
                    first_line = False

                keyword_data, keyword_length = self._fill_np_matrix(
                        [keyword] * _BATCH_SIZE)
                context_data, context_length = self._fill_np_matrix(
                        [context] * _BATCH_SIZE)
                char = start_of_sentence()

                word_count = 0
                state = ''  # placeholder; a real decoder state is returned by the first _compute_prob_list call
                while word_count < 7:
                    prob_list, state = self._compute_prob_list(char,keyword_data,keyword_length,\
                        context_data,context_length,context,state,session,pron_dict)
                    
                    # For the first character, randomly sample BEAM_SIZE candidates
                    # and keep the most probable one, so repeated runs can produce
                    # different poems.
                    if word_count == 0:
                        prob_sums = np.cumsum(prob_list)
                        # arrays holding the candidate first characters and their scores

                        char_array = []
                        score_array = []
                        for i in range(BEAM_SIZE):
                            char_array.append('')
                            score_array.append(1)

                        for i in range(BEAM_SIZE):
                            rand_val = prob_sums[-1] * random()
                            for j, prob_sum in enumerate(prob_sums):
                                if rand_val < prob_sum:
                                    char_array[i] = self.char_dict.int2char(j)
                                    score_array[i] *= -math.log(prob_list[j])
                                    break
                        # scores are negative log probabilities, so the smallest score is the most likely
                        min_value = 1000
                        min_index = 0
                        for k in range(len(score_array)):
                            if score_array[k] < min_value:
                                min_index = k
                                min_value = score_array[k]
                        char = char_array[min_index]
                        
                        # deterministic alternative (argmax): generates the same poem for the same keywords
                        '''
                        max_value = prob_list[0]
                        max_index = 0
                        for k in range(len(prob_list)):
                            if prob_list[k] > max_value:
                                max_index = k
                                max_value = prob_list[k]
                        char = self.char_dict.int2char(max_index)
                        '''
                        context += char
                        word_count += 1
                        # end of first word

                    else:
                        # depth-2 beam search: take the BEAM_SIZE most likely next
                        # characters, extend each with its best successor, and keep
                        # the pair with the lowest total score
                        char_array = []
                        second_char_array = []
                        score_array = []

                        for i in range(BEAM_SIZE):
                            char_array.append('')
                            second_char_array.append('')
                            score_array.append(1)
                        
                        # pick the BEAM_SIZE most likely candidates for the first character of the pair
                        for i in range(BEAM_SIZE):
                            char_array[i], score, used_index = self._return_n_most_likely(prob_list,i+1)
                            score_array[i] *= score
                            # make sure the same character is not selected again
                            # (prob_list is also zeroed inside _return_n_most_likely)
                            prob_list[used_index] = 0


                        # for each candidate first character, choose its most likely successor
                        for i in range(BEAM_SIZE):
                            current_context = context + char_array[i]
                            prob_list, state = self._compute_prob_list(char_array[i],keyword_data,keyword_length,\
                                context_data,context_length,current_context,state,session,pron_dict)
                            second_char_array[i], score, used_index = self._return_n_most_likely(prob_list,1)
                            # take the best successor, but make sure it does not repeat
                            # the first character or a character already in the context
                            # random_sample = second_char_array[randrange(len(second_char_array))]
                            random_sample = second_char_array[i]
                            used_chars = set(ch for ch in context)

                            tmp = 2

                            while random_sample == char_array[i] or random_sample in used_chars:
                                second_char_array[i], score, used_index = self._return_n_most_likely(prob_list,tmp)
                                random_sample = second_char_array[i]
                                tmp += 1
                            score_array[i] *= score

                        # scores are negative log probabilities, so the smallest score marks the best pair
                        min_value = 1000
                        min_index = 0
                        for i in range(len(score_array)):
                            if score_array[i] < min_value:
                                min_index = i
                                min_value = score_array[i]
                        
                        # if the chosen first character already appears in the context,
                        # fall back to the next-best candidate
                        used_chars = set(ch for ch in context)
                        first_char = char_array[min_index]
                        in_loop = 0
                        
                        while first_char in used_chars and in_loop < len(char_array):
                            score_array[min_index] = 1000
                            min_value = 1000
                            for i in range(len(score_array)):
                                # find the minimum in the remaining
                                if score_array[i] < min_value:
                                    min_index = i
                                    min_value = score_array[i]
                            first_char = char_array[min_index]
                            in_loop += 1

                        first_char = char_array[min_index]
                        second_char = second_char_array[min_index]

                        context += first_char
                        context += second_char
                        char = second_char
                        word_count += 2
                # append the <END> label
                context += end_of_sentence()
            # strip the random hint that was prepended to the first line
            context = context[0] + context[len(hint) + 1:]
        return context[1:].split(end_of_sentence())

    def _gen_prob_list(self, probs, context, pron_dict):
        prob_list = probs.tolist()[0]
        prob_list[0] = 0
        prob_list[-1] = 0
        idx = len(context)
        used_chars = set(ch for ch in context)
        for i in range(1, len(prob_list) - 1):
            ch = self.char_dict.int2char(i)
            # Penalize used characters.
            if ch in used_chars:
                prob_list[i] *= 0.2
            # Penalize rhyming violations.
            if (idx == 15 or idx == 31) and \
                    not pron_dict.co_rhyme(ch, context[7]):
                prob_list[i] *= 0.2
            # Penalize tonal violations.
            if idx > 2 and 2 == idx % 8 and \
                    not pron_dict.counter_tone(context[2], ch):
                prob_list[i] *= 0.4
            if (4 == idx % 8 or 6 == idx % 8) and \
                    not pron_dict.counter_tone(context[idx - 2], ch):
                prob_list[i] *= 0.4
        return prob_list
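    # Index arithmetic in _gen_prob_list assumes the layout '^' + (7 chars + '$')
    # per line: idx == 15 and idx == 31 are the last characters of lines 2 and 4,
    # which must rhyme with context[7] (the last character of line 1), and
    # idx % 8 in {2, 4, 6} marks the tonal checkpoints within a line. Note that in
    # this variant the random hint prepended in generate() shifts these offsets,
    # since the hint is only stripped after generation.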

    def train(self, n_epochs = 6):
        print("Training RNN-based generator ...")
        with tf.Session() as session:
            self._initialize_session(session)
            try:
                for epoch in range(n_epochs):
                    batch_no = 0
                    for keywords, contexts, sentences \
                            in batch_train_data(_BATCH_SIZE):
                        sys.stdout.write("[Seq2Seq Training] epoch = %d, " \
                                "line %d to %d ..." % 
                                (epoch, batch_no * _BATCH_SIZE,
                                (batch_no + 1) * _BATCH_SIZE))
                        sys.stdout.flush()
                        self._train_a_batch(session, epoch,
                                keywords, contexts, sentences)
                        batch_no += 1
                        if 0 == batch_no % 32:
                            self.saver.save(session, _model_path)
                    self.saver.save(session, _model_path)
                print("Training is done.")
            except KeyboardInterrupt:
                print("Training is interrupted.")

    def _train_a_batch(self, session, epoch, keywords, contexts, sentences):
        keyword_data, keyword_length = self._fill_np_matrix(keywords)
        context_data, context_length = self._fill_np_matrix(contexts)
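        # teacher forcing: the decoder reads '^' + sentence[:-1] while the targets
        # below are the full sentence, so each step predicts the next character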
        decoder_inputs, decoder_input_length  = self._fill_np_matrix(
                [start_of_sentence() + sentence[:-1] \
                        for sentence in sentences])
        targets = self._fill_targets(sentences)
        feed_dict = {
                self.keyword : keyword_data,
                self.keyword_length : keyword_length,
                self.context : context_data,
                self.context_length : context_length,
                self.decoder_inputs : decoder_inputs,
                self.decoder_input_length : decoder_input_length,
                self.targets : targets
                }
        loss, learning_rate, _ = session.run(
                [self.loss, self.learning_rate, self.opt_step],
                feed_dict = feed_dict)
        print(" loss =  %f, learning_rate = %f" % (loss, learning_rate))

    def _fill_np_matrix(self, texts):
        max_time = max(map(len, texts))
        matrix = np.zeros([_BATCH_SIZE, max_time, CHAR_VEC_DIM], 
                dtype = np.float32)
        # pre-fill every position with the end-of-sentence vector (the padding
        # symbol), then overwrite the real characters below
        for i in range(_BATCH_SIZE):
            for j in range(max_time):
                matrix[i, j, :] = self.char2vec.get_vect(end_of_sentence())
        for i, text in enumerate(texts):
            matrix[i, : len(text)] = self.char2vec.get_vects(text)
        seq_length = [len(texts[i]) if i < len(texts) else 0 \
                for i in range(_BATCH_SIZE)]
        return matrix, seq_length

    def _fill_targets(self, sentences):
        targets = []
        for sentence in sentences:
            targets.extend(map(self.char_dict.char2int, sentence))
        return targets
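
# Hedged usage sketch for Example 1 (the train.py driver and keyword extraction
# are assumed to exist elsewhere in the repository):
#     generator = Generator()
#     generator.train(n_epochs = 6)
#     poem = generator.generate(keywords)  # len(keywords) == NUM_OF_SENTENCES + 1;
#                                          # the extra keyword serves as the random hint
#     for sentence in poem:
#         print(sentence)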
Example 2
class Generator(Singleton):
    def _build_keyword_encoder(self):
        """ Encode keyword into a vector."""
        self.keyword = tf.placeholder(shape=[_BATCH_SIZE, None, CHAR_VEC_DIM],
                                      dtype=tf.float32,
                                      name="keyword")
        self.keyword_length = tf.placeholder(shape=[_BATCH_SIZE],
                                             dtype=tf.int32,
                                             name="keyword_length")
        _, bi_states = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=tf.contrib.rnn.GRUCell(_NUM_UNITS // 2),
            cell_bw=tf.contrib.rnn.GRUCell(_NUM_UNITS // 2),
            inputs=self.keyword,
            sequence_length=self.keyword_length,
            dtype=tf.float32,
            time_major=False,
            scope="keyword_encoder")
        self.keyword_state = tf.concat(bi_states, axis=1)
        tf.TensorShape([_BATCH_SIZE, _NUM_UNITS]).\
                assert_same_rank(self.keyword_state.shape)

    def _build_context_encoder(self):
        """ Encode context into a list of vectors. """
        self.context = tf.placeholder(shape=[_BATCH_SIZE, None, CHAR_VEC_DIM],
                                      dtype=tf.float32,
                                      name="context")
        self.context_length = tf.placeholder(shape=[_BATCH_SIZE],
                                             dtype=tf.int32,
                                             name="context_length")
        bi_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=tf.contrib.rnn.GRUCell(_NUM_UNITS // 2),
            cell_bw=tf.contrib.rnn.GRUCell(_NUM_UNITS // 2),
            inputs=self.context,
            sequence_length=self.context_length,
            dtype=tf.float32,
            time_major=False,
            scope="context_encoder")
        self.context_outputs = tf.concat(bi_outputs, axis=2)
        tf.TensorShape([_BATCH_SIZE, None, _NUM_UNITS]).\
                assert_same_rank(self.context_outputs.shape)

    def _build_decoder(self):
        """ Decode keyword and context into a sequence of vectors. """
        attention = tf.contrib.seq2seq.BahdanauAttention(
            num_units=_NUM_UNITS,
            memory=self.context_outputs,
            memory_sequence_length=self.context_length)
        decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
            cell=tf.contrib.rnn.GRUCell(_NUM_UNITS),
            attention_mechanism=attention)
        '''
        Slight change from the paper: the keyword's hidden state is not used as
        the attention's a_0 but as the decoder's initial state.
        '''
        self.decoder_init_state = decoder_cell.zero_state(
                batch_size = _BATCH_SIZE, dtype = tf.float32).\
                        clone(cell_state = self.keyword_state)
        self.decoder_inputs = tf.placeholder(
            shape=[_BATCH_SIZE, None, CHAR_VEC_DIM],
            dtype=tf.float32,
            name="decoder_inputs")
        self.decoder_input_length = tf.placeholder(shape=[_BATCH_SIZE],
                                                   dtype=tf.int32,
                                                   name="decoder_input_length")
        self.decoder_outputs, self.decoder_final_state = tf.nn.dynamic_rnn(
            cell=decoder_cell,
            inputs=self.decoder_inputs,
            sequence_length=self.decoder_input_length,
            initial_state=self.decoder_init_state,
            dtype=tf.float32,
            time_major=False,
            scope="training_decoder")
        tf.TensorShape([_BATCH_SIZE, None, _NUM_UNITS]).\
                assert_same_rank(self.decoder_outputs.shape)

    def _build_projector(self):
        """ Project decoder_outputs into character space. """
        softmax_w = tf.Variable(tf.random_normal(
            shape=[_NUM_UNITS, len(self.char_dict)], mean=0.0, stddev=0.08),
                                trainable=True)
        softmax_b = tf.Variable(tf.random_normal(shape=[len(self.char_dict)],
                                                 mean=0.0,
                                                 stddev=0.08),
                                trainable=True)
        reshaped_outputs = self._reshape_decoder_outputs()
        self.logits = tf.nn.bias_add(tf.matmul(reshaped_outputs, softmax_w),
                                     bias=softmax_b)
        self.probs = tf.nn.softmax(self.logits)

    def _reshape_decoder_outputs(self):
        """ Reshape decoder_outputs into shape [?, _NUM_UNITS]. """
        def concat_output_slices(idx, val):
            output_slice = tf.slice(
                input_=self.decoder_outputs,
                begin=[idx, 0, 0],
                size=[1, self.decoder_input_length[idx], _NUM_UNITS])
            return tf.add(idx, 1),\
                    tf.concat([val, tf.squeeze(output_slice, axis = 0)],
                            axis = 0)

        tf_i = tf.constant(0)
        tf_v = tf.zeros(shape=[0, _NUM_UNITS], dtype=tf.float32)
        _, reshaped_outputs = tf.while_loop(cond=lambda i, v: i < _BATCH_SIZE,
                                            body=concat_output_slices,
                                            loop_vars=[tf_i, tf_v],
                                            shape_invariants=[
                                                tf.TensorShape([]),
                                                tf.TensorShape(
                                                    [None, _NUM_UNITS])
                                            ])
        tf.TensorShape([None, _NUM_UNITS]).\
                assert_same_rank(reshaped_outputs.shape)
        return reshaped_outputs

    def _build_optimizer(self):
        """ Define cross-entropy loss and minimize it. """
        self.targets = tf.placeholder(shape=[None],
                                      dtype=tf.int32,
                                      name="targets")
        labels = tf.one_hot(self.targets, depth=len(self.char_dict))
        cross_entropy = tf.losses.softmax_cross_entropy(onehot_labels=labels,
                                                        logits=self.logits)
        self.loss = tf.reduce_mean(cross_entropy)

        self.learning_rate = tf.clip_by_value(tf.multiply(
            1.6e-5, tf.pow(2.1, self.loss)),
                                              clip_value_min=0.0002,
                                              clip_value_max=0.02)
        self.opt_step = tf.train.AdamOptimizer(
                learning_rate = self.learning_rate).\
                        minimize(loss = self.loss)

    def _build_graph(self):
        # 256 units per direction, single-layer bidirectional GRU
        self._build_keyword_encoder()
        # 256 units per direction, single-layer bidirectional GRU
        self._build_context_encoder()
        # Bahdanau attention + 512-unit single-layer GRU
        self._build_decoder()
        # 512 -> character space, one fully connected layer
        self._build_projector()
        # cross-entropy loss, Adam optimizer
        self._build_optimizer()
        self._build_optimizer()

    def __init__(self):
        self.char_dict = CharDict()
        self.char2vec = Char2Vec()
        self._build_graph()
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        self.saver = tf.train.Saver(tf.global_variables())
        # if not os.path.exists("save/model.meta"):
        # else:
        #     self.saver = tf.train.import_meta_graph("save/model.meta")
        self.trained = False

    def _initialize_session(self, session):
        checkpoint = tf.train.get_checkpoint_state(save_dir)
        if not checkpoint or not checkpoint.model_checkpoint_path:
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            session.run(init_op)
        else:
            self.saver.restore(session, checkpoint.model_checkpoint_path)
            self.trained = True

    def generate(self, keywords):
        assert NUM_OF_SENTENCES == len(keywords)
        pron_dict = PronDict()
        context = start_of_sentence()
        with tf.Session() as session:
            self._initialize_session(session)
            if not self.trained:
                print("Please train the model first! (./train.py -g)")
                sys.exit(1)
            for keyword in keywords:
                keyword_data, keyword_length = self._fill_np_matrix(
                    [keyword] * _BATCH_SIZE)
                context_data, context_length = self._fill_np_matrix(
                    [context] * _BATCH_SIZE)
                char = start_of_sentence()
                for _ in range(7):
                    decoder_input, decoder_input_length = \
                            self._fill_np_matrix([char])
                    encoder_feed_dict = {
                        self.keyword: keyword_data,
                        self.keyword_length: keyword_length,
                        self.context: context_data,
                        self.context_length: context_length,
                        self.decoder_inputs: decoder_input,
                        self.decoder_input_length: decoder_input_length
                    }
                    if char != start_of_sentence():
                        encoder_feed_dict[self.decoder_init_state] = state
                    probs, state = session.run(
                        [self.probs, self.decoder_final_state],
                        feed_dict=encoder_feed_dict)
                    prob_list = self._gen_prob_list(probs, context, pron_dict)
                    prob_sums = np.cumsum(prob_list)
                    rand_val = prob_sums[-1] * random()
                    for i, prob_sum in enumerate(prob_sums):
                        if rand_val < prob_sum:
                            char = self.char_dict.int2char(i)
                            break
                    context += char
                context += end_of_sentence()
        return context[1:].split(end_of_sentence())
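    # The cumulative-sum loop above samples the next character from the penalized
    # distribution; it is roughly equivalent to:
    #     p = np.array(prob_list) / np.sum(prob_list)
    #     char = self.char_dict.int2char(np.random.choice(len(prob_list), p=p))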

    def _gen_prob_list(self, probs, context, pron_dict):
        prob_list = probs.tolist()[0]
        prob_list[0] = 0
        prob_list[-1] = 0
        idx = len(context)
        used_chars = set(ch for ch in context)
        for i in range(1, len(prob_list) - 1):
            ch = self.char_dict.int2char(i)
            # Penalize used characters.
            if ch in used_chars:
                prob_list[i] *= 0.6
            # Penalize rhyming violations.
            if (idx == 15 or idx == 31) and \
                    not pron_dict.co_rhyme(ch, context[7]):
                prob_list[i] *= 0.2
            # Penalize tonal violations.
            if idx > 2 and 2 == idx % 8 and \
                    not pron_dict.counter_tone(context[2], ch):
                prob_list[i] *= 0.4
            if (4 == idx % 8 or 6 == idx % 8) and \
                    not pron_dict.counter_tone(context[idx - 2], ch):
                prob_list[i] *= 0.4
        return prob_list

    def train(self, n_epochs=6):
        print("Training RNN-based generator ...")
        with tf.Session(config=tf.ConfigProto(
                log_device_placement=True)) as session:
            self._initialize_session(session)
            try:
                for epoch in range(n_epochs):
                    batch_no = 0
                    for keywords, contexts, sentences \
                            in batch_train_data(_BATCH_SIZE):
                        sys.stdout.write("[Seq2Seq Training] epoch = %d, " \
                                "line %d to %d ..." %
                                (epoch, batch_no * _BATCH_SIZE,
                                (batch_no + 1) * _BATCH_SIZE))
                        sys.stdout.flush()
                        self._train_a_batch(session, epoch, keywords, contexts,
                                            sentences)
                        batch_no += 1
                        # if 0 == batch_no % 32:
                        #     with open('save/check_epoch', 'a+') as file:
                        #         file.write('{}-{}\n'.format(epoch, batch_no))
                        #     self.saver.save(session, _model_path)
                    with open('save/check_epoch', 'a+') as file:
                        file.write('{}\n'.format(epoch))
                    self.saver.save(session, _model_path)
                print("Training is done.")
            except KeyboardInterrupt:
                print("Training is interrupted.")

    def _train_a_batch(self, session, epoch, keywords, contexts, sentences):
        # padding
        keyword_data, keyword_length = self._fill_np_matrix(keywords)
        context_data, context_length = self._fill_np_matrix(contexts)
        decoder_inputs, decoder_input_length  = self._fill_np_matrix(
                [start_of_sentence() + sentence[:-1] \
                        for sentence in sentences])
        targets = self._fill_targets(sentences)
        # feed all placeholders
        feed_dict = {
            self.keyword: keyword_data,
            self.keyword_length: keyword_length,
            self.context: context_data,
            self.context_length: context_length,
            self.decoder_inputs: decoder_inputs,
            self.decoder_input_length: decoder_input_length,
            self.targets: targets
        }
        loss, learning_rate, _ = session.run(
            [self.loss, self.learning_rate, self.opt_step],
            feed_dict=feed_dict)
        print(" loss =  %f, learning_rate = %f" % (loss, learning_rate))
        with open('save/loss.log', 'a+') as file:
            file.write("{}: {}\n".format(epoch, loss))

    def _fill_np_matrix(self, texts):
        max_time = max(map(len, texts))
        matrix = np.zeros([_BATCH_SIZE, max_time, CHAR_VEC_DIM],
                          dtype=np.float32)
        for i in range(_BATCH_SIZE):
            for j in range(max_time):
                # pad with the end-of-sentence vector
                matrix[i, j, :] = self.char2vec.get_vect(end_of_sentence())
        for i, text in enumerate(texts):
            matrix[i, :len(text)] = self.char2vec.get_vects(text)
        seq_length = [len(texts[i]) if i < len(texts) else 0 \
                for i in range(_BATCH_SIZE)]
        return matrix, seq_length

    def _fill_targets(self, sentences):
        targets = []
        for sentence in sentences:
            targets.extend(map(self.char_dict.char2int, sentence))
        return targets
Example 3
class GenerateTransformerModel(tf.keras.Model):
    def __init__(self, isTrain):
        super(GenerateTransformerModel, self).__init__()

        self.char_dict = CharDict()
        self.char2vec = Char2Vec()
        self.learning_rate = 0.001

        if not os.path.exists(save_dir):
            os.mkdir(save_dir)

        self.encoder = Encoder(isTrain)
        self.decoder = Decoder(len(self.char_dict), isTrain)

        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=self.learning_rate)

        self.checkpoint = tf.train.Checkpoint(encoder=self.encoder,
                                              decoder=self.decoder,
                                              optimizer=self.optimizer)
        self.manager = tf.train.CheckpointManager(self.checkpoint,
                                                  save_dir,
                                                  max_to_keep=3)

    def generate(self, keywords):
        if not tf.train.get_checkpoint_state(save_dir):
            print("Please train the model first! (./train.py -g)")
            sys.exit(1)

        self.checkpoint.restore(self.manager.latest_checkpoint)
        print("Checkpoint is loaded successfully !")
        assert NUM_OF_SENTENCES == len(keywords)
        context = start_of_sentence()
        pron_dict = PronDict()
        for keyword in keywords:
            keyword_data, keyword_length = self._fill_np_matrix([keyword] *
                                                                _BATCH_SIZE)
            context_data, context_length = self._fill_np_matrix([context] *
                                                                _BATCH_SIZE)

            encoder_output = self.encoder(keyword_data, context_data)
            char = start_of_sentence()
            for _ in range(7):
                decoder_input, decoder_input_length = \
                    self._fill_np_matrix([char])
                if char == start_of_sentence():
                    pass
                else:
                    encoder_output = decoder_output
                probs, logits, decoder_output = self.decoder(
                    encoder_output, decoder_input, decoder_input_length)
                prob_list = self._gen_prob_list(probs, context, pron_dict)
                prob_sums = np.cumsum(prob_list)
                rand_val = prob_sums[-1] * random()
                for i, prob_sum in enumerate(prob_sums):
                    if rand_val < prob_sum:
                        char = self.char_dict.int2char(i)
                        break
                context += char
            context += end_of_sentence()

        return context[1:].split(end_of_sentence())

    def _gen_prob_list(self, probs, context, pron_dict):
        prob_list = probs.numpy().tolist()[0]
        prob_list[0] = 0
        prob_list[-1] = 0
        idx = len(context)
        used_chars = set(ch for ch in context)
        for i in range(1, len(prob_list) - 1):
            ch = self.char_dict.int2char(i)
            # Penalize used characters.
            if ch in used_chars:
                prob_list[i] *= 0.6
            # Penalize rhyming violations.
            if (idx == 15 or idx == 31) and \
                    not pron_dict.co_rhyme(ch, context[7]):
                prob_list[i] *= 0.2
            # Penalize tonal violations.
            if idx > 2 and 2 == idx % 8 and \
                    not pron_dict.counter_tone(context[2], ch):
                prob_list[i] *= 0.4
            if (4 == idx % 8 or 6 == idx % 8) and \
                    not pron_dict.counter_tone(context[idx - 2], ch):
                prob_list[i] *= 0.4
        return prob_list

    def train(self, n_epochs):
        print("Training RNN-based generator ...")
        try:
            for epoch in range(n_epochs):
                batch_no = 0
                for keywords, contexts, sentences in batch_train_data(
                        _BATCH_SIZE):
                    sys.stdout.write(
                        "[Seq2Seq Training] epoch = %d, line %d to %d ..." %
                        (epoch, batch_no * _BATCH_SIZE,
                         (batch_no + 1) * _BATCH_SIZE))
                    sys.stdout.flush()
                    self._train_a_batch(keywords, contexts, sentences)
                    batch_no += 1
                    if 0 == batch_no % 32:
                        self.manager.save()
                self.manager.save()
            print("Training is done.")
        except KeyboardInterrupt:
            print("Training is interrupted.")

    def _train_a_batch(self, keywords, contexts, sentences):
        keyword_data, keyword_length = self._fill_np_matrix(keywords)
        context_data, context_length = self._fill_np_matrix(contexts)
        decoder_input, decoder_input_length = self._fill_np_matrix(
            [start_of_sentence() + sentence[:-1] for sentence in sentences])
        targets = self._fill_targets(sentences)

        # sentences comes from data_utils --> (sentence, keyword, context); for example:
        #澄潭皎镜石崔巍$ 石   ^
        #万壑千岩暗绿苔$	暗	^澄潭皎镜石崔巍$

        # loss, learning_rate = 0
        with tf.GradientTape() as tape:
            encoder_output = self.encoder(keyword_data, context_data)
            probs, logits, decoder_output = self.decoder(
                encoder_output, decoder_input, decoder_input_length)
            loss = self.loss_func(targets, logits, probs)

            learning_rate = self.learning_rate_func(loss)
            optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

            print(" loss =  %f, learning_rate = %f" % (loss, learning_rate))

        variables = self.encoder.trainable_variables + self.decoder.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))
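        # Note: a fresh Adam optimizer is created for every batch, so its moment
        # estimates never accumulate across steps, and the checkpointed
        # self.optimizer is not the one applying these gradients.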

    def loss_func(self, targets, logits, probs):
        labels = self.label_smoothing(
            tf.one_hot(targets, depth=len(self.char_dict)))
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels,
                                                       logits=logits)
        return tf.reduce_mean(loss)

    def label_smoothing(self, inputs, epsilon=0.1):
        V = inputs.get_shape().as_list()[-1]  # number of channels
        return ((1 - epsilon) * inputs) + (epsilon / V)
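        # Worked example: with V = 4 and epsilon = 0.1, a one-hot row
        # [1, 0, 0, 0] becomes [0.925, 0.025, 0.025, 0.025].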

    def learning_rate_func(self, loss):
        learning_rate = tf.clip_by_value(tf.multiply(1.6e-5, tf.pow(2.1,
                                                                    loss)),
                                         clip_value_min=0.0002,
                                         clip_value_max=0.02)
        return learning_rate

    def _fill_targets(self, sentences):
        targets = []
        for sentence in sentences:
            targets.extend(map(self.char_dict.char2int, sentence))
        return targets

    def _fill_np_matrix(self, texts):
        max_time = max(map(len, texts))  # length of the longest text in this batch
        matrix = np.zeros([_BATCH_SIZE, max_time, CHAR_VEC_DIM],
                          dtype=np.float32)
        for i in range(_BATCH_SIZE):
            for j in range(max_time):
                matrix[i, j, :] = self.char2vec.get_vect(end_of_sentence())
        for i, text in enumerate(texts):
            matrix[i, :len(text)] = self.char2vec.get_vects(text)
        seq_length = [len(texts[i]) if i < len(texts) else 0 \
                      for i in range(_BATCH_SIZE)]
        return matrix, seq_length
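
# Hedged usage sketch for Example 3 (the Encoder, Decoder and train.py driver are
# assumed to exist elsewhere in the repository):
#     model = GenerateTransformerModel(isTrain=True)
#     model.train(n_epochs=6)
#     poem = GenerateTransformerModel(isTrain=False).generate(keywords)  # len(keywords) == NUM_OF_SENTENCES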