Example #1
    def __init__(self, gpu_number=0):
        # Load model config
        config = load_config(FLAGS)

        config_proto = tf.ConfigProto(
            # allow_soft_placement=FLAGS.allow_soft_placement,
            # log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(
                allow_growth=True)  #, visible_device_list=str(gpu_number))
        )
        self.graphpre = tf.Graph()
        self.sess = tf.Session(graph=self.graphpre, config=config_proto)

        with self.sess.as_default():
            with self.graphpre.as_default():

                # Build the model
                self.model = Seq2SeqModel(config, 'predict')

                # Create saver
                # Using var_list = None returns the list of all saveable variables
                saver = tf.train.Saver(var_list=None)

                # Reload existing checkpoint
                load_model(self.sess, self.model, saver)
                self.planner = Planner()

                print("poetry is ok!")
Example #2
def main(_):
    planner = Planner()
    text = input('Input Text:\n').strip()  # read the source text from stdin
    keywords = planner.plan(text)
    with Seq2SeqPredictor() as predictor:
        lines = predictor.predict(keywords)
        for line in lines:
            print(line)
Example #3
def __init__(self):
    self.char_dict = CharDict()
    self.char2vec = Char2Vec()
    # Edited:
    self.planner = Planner()
    self._build_graph()
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    self.saver = tf.train.Saver(tf.global_variables())
    self.trained = False
Example #4
File: app.py Project: Abingcbc/MuTao
def generate():
    text = request.args['keywords']
    planner = Planner()
    generator = Generator()
    keywords = planner.plan(text)
    # print("Keywords: " + ' '.join(keywords))
    poem = generator.generate(keywords)
    # print("Poem generated:")
    # for sentence in poem:
    #     print(sentence)
    return '\n'.join(poem)
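The generate() handler above reads its input from Flask's request.args, but the route wiring is not shown. A minimal sketch of the surrounding Flask app, assuming the handler lives in app.py and is served at a hypothetical /generate route:

from flask import Flask, request
from plan import Planner
from generate import Generator

app = Flask(__name__)

@app.route('/generate')  # hypothetical route path
def generate():
    text = request.args['keywords']
    keywords = Planner().plan(text)
    return '\n'.join(Generator().generate(keywords))

if __name__ == '__main__':
    app.run()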
Example #5
def generate_rnn_samples(sampled_poems):
    planner = Planner()
    with Seq2SeqPredictor() as predictor:
        with open(rnn_samples_path, 'w+') as fout:
            for poem_idx, poem in enumerate(sampled_poems):
                input = string.join(poem).strip()
                keywords = planner.plan(input)

                print 'Predicting poem {}.'.format(poem_idx)
                lines = predictor.predict(keywords)

                for idx, sentence in enumerate(lines):
                    punctuation = u'\uff0c' if idx % 2 == 0 else u'\u3002'
                    line = (sentence + punctuation + '\n').encode('utf-8')
                    fout.write(line)
Example #6
def eval_generated_data(num=100):
    evaluator = RhymeEvaluator()

    planner = Planner()
    predictor = Seq2SeqPredictor()

    poems = []
    for _ in range(num):
        keywords = planner.plan(u'')
        assert 4 == len(keywords)

        sentences = predictor.predict(keywords)
        poems.append(sentences)

    print("Testing {} quatrains generated by model.".format(num))
    eval_poems(evaluator, poems)
Example #7
def eval_generated_data(num=4000):
    evaluator = RhymeEvaluator()

    planner = Planner()
    predictor = Seq2SeqPredictor()

    poems = []
    sentences = []
    i = 1
    with open("./data/samples/default.txt") as f:
        for text in f.readlines():

            sentences.append(text.strip()[:-1])
            if i % 4 == 0:
                poems.append(sentences)
                sentences = []
            i += 1
    # for _ in range(num):
    #     keywords = planner.plan(u'')
    #     assert 4 == len(keywords)
    #
    #     sentences = predictor.predict(keywords)
    #     poems.append(sentences)

    print("Testing {} quatrains generated by model.".format(num))
    eval_poems(evaluator, poems)
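The counter in the loop above groups every four file lines into one quatrain. Assuming the same file format, an equivalent slice-based chunking looks like this (any trailing partial group is dropped):

with open("./data/samples/default.txt") as f:
    lines = [text.strip()[:-1] for text in f]
poems = [lines[i:i + 4] for i in range(0, len(lines) - len(lines) % 4, 4)]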
Example #8
def main(args, cangtou=False):
    planner = Planner()
    with Seq2SeqPredictor() as predictor:
        # Run loop
        terminate = False
        Judge = MatchUtil()
        while not terminate:
            try:
                input = args.Input.decode('utf-8').strip()
                if not input:
                    print 'Input cannot be empty!'
                elif input.lower() in ['quit', 'exit']:
                    terminate = True
                else:
                    if cangtou:
                        keywords = get_cangtou_keywords(input)
                    else:
                        # Generate keywords
                        keywords = planner.plan(input)

                    # Generate poem
                    lines = predictor.predict(keywords)

                    # Check whether the couplet conforms to the rhyme rules
                    result = Judge.eval_rhyme(lines)

                    if result:
                        # Print keywords and poem
                        print 'Keyword:\t\tPoem:'
                        for line_number in xrange(2):
                            punctuation = u',' if line_number % 2 == 0 else u'。'
                            print u'{keyword}\t\t{line}{punctuation}'.format(
                                keyword=keywords[line_number],
                                line=lines[line_number],
                                punctuation=punctuation)
                        terminate = True

            except EOFError:
                terminate = True
            except KeyboardInterrupt:
                terminate = True
    print '\nTerminated.'
Example #9
def main(args, cangtou=False):

    planner = Planner()
    with Seq2SeqPredictor() as predictor:
        # Run loop
        terminate = False
        while not terminate:
            try:
                input = raw_input('Input Text:\n').decode('utf-8').strip()

                if not input:
                    print 'Input cannot be empty!'
                elif input.lower() in ['quit', 'exit']:
                    terminate = True
                else:
                    if cangtou:
                        keywords = get_cangtou_keywords(input)
                    else:
                        # Generate keywords
                        keywords = planner.plan(input)

                    # Generate poem
                    if args.nocouplet:
                        lines = predictor.predict(keywords)
                    else:
                        lines = predictor.predict_with_couplet(keywords)

                    # Print keywords and poem
                    print 'Keyword:\t\tPoem:'
                    for line_number in xrange(4):
                        punctuation = u',' if line_number % 2 == 0 else u'。'
                        print u'{keyword}\t\t{line}{punctuation}'.format(
                            keyword=keywords[line_number],
                            line=lines[line_number],
                            punctuation=punctuation)

            except EOFError:
                terminate = True
            except KeyboardInterrupt:
                terminate = True
    print '\nTerminated.'
Example #10
def main():
    """ The entry point of red agent """
    listener_port = constant.DEFAULT_LISTENER_PORT

    # TODO: handle command-line arguments more elegantly.
    if len(sys.argv) > 1:
        listener_port = sys.argv[1]
    AttackDB().update()

    Listener().work(constant.DEFAULT_LISTENER_IP, int(listener_port))

    assumed_state = {
        constant.STATE_KEY_PROTOCOL: constant.STATE_VALUE_PROTOCOL_IP,
        constant.STATE_KEY_ADDRESS: constant.DEFAULT_TARGET_IP
    }

    goal = Condition()
    goal.add((constant.STATE_KEY_SHELL, constant.CONDITION_OPERATOR_EQUAL,
              constant.STATE_VALUE_SHELL_PERMANENT))
    goal.add((constant.STATE_KEY_PRIVILEGE, constant.CONDITION_OPERATOR_EQUAL,
              constant.STATE_VALUE_PRIVILEGE_ROOT))

    planner = Planner(assumed_state, goal)
    try:
        planner.make_plan()
        print(planner)
        planner.run()

    except Exception as e:
        traceback.print_exc()
        Listener().stop()
        raise e

    Listener().stop()
    sys.exit(0)
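The Condition and state APIs are not shown in this example. As a purely hypothetical reading of the (key, operator, value) triples passed to goal.add(), a goal check against a state dictionary might look like this (names and operator set are assumptions, not the project's actual code):

def goal_satisfied(conditions, state):
    # Hypothetical: CONDITION_OPERATOR_EQUAL compares for equality.
    operators = {'==': lambda a, b: a == b}
    return all(operators[op](state.get(key), value)
               for key, op, value in conditions)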
Example #11
def main(cangtou=False):
    planner = Planner()
    with Seq2SeqPredictor() as predictor:
        # Run loop
        terminate = False
        while not terminate:
            try:
                # import ipdb; ipdb.set_trace()  # optional debugger breakpoint
                text = input('Input Text:\n').strip()  # Python 3 input() already returns str

                if not text:
                    print('Input cannot be empty!')
                elif text.lower() in ['quit', 'exit']:
                    terminate = True
                else:
                    if cangtou:
                        keywords = get_cangtou_keywords(text)
                    else:
                        # Generate keywords
                        keywords = planner.plan(text)

                    # Generate poem
                    lines = predictor.predict(keywords)

                    # Print keywords and poem
                    print('Keyword:\t\tPoem:')
                    for line_number in range(4):
                        punctuation = ',' if line_number % 2 == 0 else '。'
                        print('{keyword}\t\t{line}{punctuation}'.format(
                            keyword=keywords[line_number],
                            line=lines[line_number],
                            punctuation=punctuation))

            except EOFError:
                terminate = True
            except KeyboardInterrupt:
                terminate = True
    print('\nTerminated.')
Example #12
def main(cangtou=False):
    planner = Planner()
    with Seq2SeqPredictor() as predictor:
        # Run loop
        terminate = False
        while not terminate:
            try:
                inputs = input('Input Text:\n').strip()

                if not inputs:
                    print('Input cannot be empty!')
                elif inputs.lower() in ['quit', 'exit']:
                    terminate = True
                else:
                    if cangtou:
                        keywords = get_cangtou_keywords(inputs)
                    else:
                        # Generate keywords: segment the input, rank the words
                        # by TextRank score in descending order, and keep the
                        # top four as keywords
                        keywords = planner.plan(inputs)
                        print(keywords)
                    # Generate poem
                    lines = predictor.predict(keywords)
                    # Print keywords and poem
                    print('Keyword:\t\tPoem:')
                    for line_number in range(4):
                        punctuation = u',' if line_number % 2 == 0 else u'。'
                        print(u'{keyword}\t\t{line}{punctuation}'.format(
                            keyword=keywords[line_number],
                            line=lines[line_number],
                            punctuation=punctuation
                        ))

            except EOFError:
                terminate = True
            except KeyboardInterrupt:
                terminate = True
    print('\nTerminated.')
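As the translated comment in the keyword branch notes, planner.plan segments the input and ranks words by TextRank score. A rough sketch of that idea using jieba's built-in TextRank extractor (a guess at the approach, not this project's implementation):

import jieba.analyse

def plan_keywords(text, num=4):
    # TextRank-ranked keywords, highest score first; repeat the last word
    # if the input yields fewer than num keywords.
    keywords = jieba.analyse.textrank(text, topK=num)
    while keywords and len(keywords) < num:
        keywords.append(keywords[-1])
    return keywords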
Example #13
class Main_Poetry_maker:
    def __init__(self):
        self.planner = Planner()
        self.predictor = Seq2SeqPredictor()
        self.Judge = MatchUtil()

    def predict(self, input_ustr):
        input_ustr = input_ustr.strip()
        keywords = self.planner.plan(input_ustr)
        lines = self.predictor.predict(keywords)
        result = self.Judge.eval_rhyme(lines)
        while not result:
            lines = self.predictor.predict(keywords)
            result = self.Judge.eval_rhyme(lines)
        return lines[0] + '   ' + lines[1]
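One caveat: the retry loop above never exits if no sampled couplet passes eval_rhyme. A bounded variant of the same method (max_tries is an added, illustrative parameter):

    def predict(self, input_ustr, max_tries=20):
        keywords = self.planner.plan(input_ustr.strip())
        lines = self.predictor.predict(keywords)
        # Resample at most max_tries times instead of looping forever.
        for _ in range(max_tries):
            if self.Judge.eval_rhyme(lines):
                break
            lines = self.predictor.predict(keywords)
        return lines[0] + '   ' + lines[1]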
Example #14
class Main_Poetry_maker:
    def __init__(self):
        self.planner = Planner()
        self.predictor = Seq2SeqPredictor()
        self.Judge = MatchUtil()

    def predict(self, input_ustr):
        input_ustr = input_ustr.strip()
        keywords = self.planner.plan(input_ustr)
        lines = self.predictor.predict(keywords)
        result = self.Judge.eval_rhyme(lines)
        while not result:
            lines = self.predictor.predict(keywords)
            result = self.Judge.eval_rhyme(lines)

        # for line_number in range(2):
        #     punctuation = u',' if line_number % 2 == 0 else u'。'
        #     (u'{keyword}\t\t{line}{punctuation}'.format(
        #             keyword=keywords[line_number],
        #             line=lines[line_number],
        #             punctuation=punctuation
        #     ))
        return ','.join(lines)
Example #15
    scores = []
    with codecs.open('results.txt', 'r', 'utf-8') as fin:
        line = fin.readline()
        while line:
            scores.append(evaluator.eval(split_sentences(line.strip())))
            line = fin.readline()
    print "Mean score = %f, standard deviation = %f" % (np.mean(scores), np.std(scores))
    quatrains = get_quatrains()
    print "Testing %d quatrains from the corpus." % len(quatrains)
    scores = []
    for quatrain in quatrains:
        score = evaluator.eval(quatrain['sentences'])
        scores.append(score)
    print "Mean score = %f, standard deviation = %f" % (np.mean(scores),
                                                        np.std(scores))
    num = 100
    print "Testing %d poems generated by RNN ..." % num
    scores = []
    planner = Planner()
    generator = Generator()
    for _ in range(num):
        keywords = planner.plan(u'')
        assert 4 == len(keywords)
        sentences = generator.generate(keywords)
        score = evaluator.eval(sentences)
        print "score = %f" % score
        scores.append(score)
    print "Mean score = %f, standard deviation = %f" % (np.mean(scores),
                                                        np.std(scores))
Example #16
class Generator(Singleton):

    def _build_keyword_encoder(self):
        """ Encode keyword into a vector."""
        self.keyword = tf.placeholder(
                shape = [_BATCH_SIZE, None, CHAR_VEC_DIM],
                dtype = tf.float32, 
                name = "keyword")
        self.keyword_length = tf.placeholder(
                shape = [_BATCH_SIZE],
                dtype = tf.int32,
                name = "keyword_length")
        _, bi_states = tf.nn.bidirectional_dynamic_rnn(
                cell_fw = tf.contrib.rnn.GRUCell(_NUM_UNITS // 2),
                cell_bw = tf.contrib.rnn.GRUCell(_NUM_UNITS // 2),
                inputs = self.keyword,
                sequence_length = self.keyword_length,
                dtype = tf.float32, 
                time_major = False,
                scope = "keyword_encoder")
        self.keyword_state = tf.concat(bi_states, axis = 1)
        tf.TensorShape([_BATCH_SIZE, _NUM_UNITS]).\
                assert_same_rank(self.keyword_state.shape)

    def _build_context_encoder(self):
        """ Encode context into a list of vectors. """
        self.context = tf.placeholder(
                shape = [_BATCH_SIZE, None, CHARPIN_VEC_DIM],
                dtype = tf.float32, 
                name = "context")
        self.context_length = tf.placeholder(
                shape = [_BATCH_SIZE],
                dtype = tf.int32,
                name = "context_length")
        bi_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                cell_fw = tf.contrib.rnn.GRUCell(_NUM_UNITS // 2),
                cell_bw = tf.contrib.rnn.GRUCell(_NUM_UNITS // 2),
                inputs = self.context,
                sequence_length = self.context_length,
                dtype = tf.float32, 
                time_major = False,
                scope = "context_encoder")  # [batch_size, max_time, output_size] for output_fw, output_bw
        self.context_outputs = tf.concat(bi_outputs, axis = 2)
        tf.TensorShape([_BATCH_SIZE, None, _NUM_UNITS]).\
                assert_same_rank(self.context_outputs.shape)

    def _build_decoder(self):
        """ Decode keyword and context into a sequence of vectors. """
        attention = tf.contrib.seq2seq.BahdanauAttention(
                num_units = _NUM_UNITS,  # the depth of the query
                memory = self.context_outputs, # the output of encoder
                memory_sequence_length = self.context_length)
        decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
                cell = tf.contrib.rnn.GRUCell(_NUM_UNITS),  # Note bidirectional only for encoder
                attention_mechanism = attention)
        self.decoder_init_state = decoder_cell.zero_state(
                batch_size = _BATCH_SIZE, dtype = tf.float32).\
                        clone(cell_state = self.keyword_state)
        self.decoder_inputs = tf.placeholder(
                shape = [_BATCH_SIZE, None, CHARPIN_VEC_DIM],
                dtype = tf.float32, 
                name = "decoder_inputs")
        self.decoder_input_length = tf.placeholder(
                shape = [_BATCH_SIZE],
                dtype = tf.int32,
                name = "decoder_input_length")
        self.decoder_outputs, self.decoder_final_state = tf.nn.dynamic_rnn(
                cell = decoder_cell,
                inputs = self.decoder_inputs,
                sequence_length = self.decoder_input_length,
                initial_state = self.decoder_init_state,
                dtype = tf.float32, 
                time_major = False,
                scope = "training_decoder")
        tf.TensorShape([_BATCH_SIZE, None, _NUM_UNITS]).\
                assert_same_rank(self.decoder_outputs.shape)

    def _build_projector(self):
        """ Project decoder_outputs into character space. """
        softmax_w = tf.Variable(
                tf.random_normal(shape = [_NUM_UNITS, len(self.char_dict)],
                    mean = 0.0, stddev = 0.08), 
                trainable = True)
        softmax_b = tf.Variable(
                tf.random_normal(shape = [len(self.char_dict)],
                    mean = 0.0, stddev = 0.08),
                trainable = True)
        reshaped_outputs = self._reshape_decoder_outputs()
        self.logits = tf.nn.bias_add(
                tf.matmul(reshaped_outputs, softmax_w),
                bias = softmax_b)
        self.probs = tf.nn.softmax(self.logits)

    def _reshape_decoder_outputs(self):
        """ Reshape decoder_outputs into shape [?, _NUM_UNITS]. """
        def concat_output_slices(idx, val):
            output_slice = tf.slice(
                    input_ = self.decoder_outputs,
                    begin = [idx, 0, 0],
                    size = [1, self.decoder_input_length[idx],  _NUM_UNITS])
            return tf.add(idx, 1),\
                    tf.concat([val, tf.squeeze(output_slice, axis = 0)], 
                            axis = 0)
        tf_i = tf.constant(0)
        tf_v = tf.zeros(shape = [0, _NUM_UNITS], dtype = tf.float32)
        _, reshaped_outputs = tf.while_loop(
                cond = lambda i, v: i < _BATCH_SIZE,
                body = concat_output_slices,
                loop_vars = [tf_i, tf_v],
                shape_invariants = [tf.TensorShape([]),
                    tf.TensorShape([None, _NUM_UNITS])])
        tf.TensorShape([None, _NUM_UNITS]).\
                assert_same_rank(reshaped_outputs.shape)
        return reshaped_outputs

    def _build_optimizer(self):
        """ Define cross-entropy loss and minimize it. """
        self.targets = tf.placeholder(
                shape = [None],
                dtype = tf.int32, 
                name = "targets")
        labels = tf.one_hot(self.targets, depth = len(self.char_dict))
        cross_entropy = tf.losses.softmax_cross_entropy(
                onehot_labels = labels,
                logits = self.logits)
        self.loss = tf.reduce_mean(cross_entropy)

        self.learning_rate = tf.clip_by_value(
                tf.multiply(1.6e-5, tf.pow(2.1, self.loss)),
                clip_value_min = 0.0002,
                clip_value_max = 0.02)
        self.opt_step = tf.train.AdamOptimizer(
                learning_rate = self.learning_rate).\
                        minimize(loss = self.loss)
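        # Note: the schedule above ties the learning rate to the current loss,
        # lr = clip(1.6e-5 * 2.1**loss, 2e-4, 2e-2), so training takes larger
        # steps while the loss is high (e.g. loss = 5 gives about 6.5e-4).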

    def _build_graph(self):
        self._build_keyword_encoder()
        self._build_context_encoder()
        self._build_decoder()
        self._build_projector()
        self._build_optimizer()

    def __init__(self):
        self.char_dict = CharDict()
        self.char2vec = Char2Vec()
        # Edited:
        self.planner = Planner()
        self._build_graph()
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        self.saver = tf.train.Saver(tf.global_variables())
        self.trained = False
        
    def _initialize_session(self, session):
        checkpoint = tf.train.get_checkpoint_state(save_dir)
        if not checkpoint or not checkpoint.model_checkpoint_path:
            init_op = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
            session.run(init_op)
        else:
            self.saver.restore(session, checkpoint.model_checkpoint_path)
            self.trained = True

    def generate(self, keywords):
        assert NUM_OF_SENTENCES == len(keywords)
        pron_dict = PronDict()
        context = start_of_sentence()
        with tf.Session() as session:
            self._initialize_session(session)
            if not self.trained:
                print("Please train the model first! (./train.py -g)")
                sys.exit(1)
            for keyword in keywords:
                keyword_data, keyword_length = self._fill_np_matrix(
                        [keyword] * _BATCH_SIZE, True)
                context_data, context_length = self._fill_np_matrix(
                        [context] * _BATCH_SIZE, False)
                char = start_of_sentence()
                for _ in range(7):
                    decoder_input, decoder_input_length = \
                            self._fill_np_matrix([char], False)
                    encoder_feed_dict = {
                            self.keyword : keyword_data,
                            self.keyword_length : keyword_length,
                            self.context : context_data,
                            self.context_length : context_length,
                            self.decoder_inputs : decoder_input,
                            self.decoder_input_length : decoder_input_length
                            }
                    if char == start_of_sentence():
                        pass
                    else:
                        encoder_feed_dict[self.decoder_init_state] = state
                    probs, state = session.run(
                            [self.probs, self.decoder_final_state],
                            feed_dict = encoder_feed_dict)
                    prob_list = self._gen_prob_list(probs, context, pron_dict)
                    prob_sums = np.cumsum(prob_list)
                    rand_val = prob_sums[-1] * random()
                    for i, prob_sum in enumerate(prob_sums):
                        if rand_val < prob_sum:
                            char = self.char_dict.int2char(i)
                            break
                    context += char
                context += end_of_sentence()
        return context[1:].split(end_of_sentence())

    def _gen_prob_list(self, probs, context, pron_dict):
        prob_list = probs.tolist()[0]
        prob_list[0] = 0
        prob_list[-1] = 0
        idx = len(context)
        used_chars = set(ch for ch in context)
        for i in range(1, len(prob_list) - 1):
            ch = self.char_dict.int2char(i)
            # Penalize used characters.
            if ch in used_chars:
                prob_list[i] *= 0.6
            # Penalize rhyming violations.
            if (idx == 15 or idx == 31) and \
                    not pron_dict.co_rhyme(ch, context[7]):
                prob_list[i] *= 0.2
            # Penalize tonal violations.
            if idx > 2 and 2 == idx % 8 and \
                    not pron_dict.counter_tone(context[2], ch):
                prob_list[i] *= 0.4
            if (4 == idx % 8 or 6 == idx % 8) and \
                    not pron_dict.counter_tone(context[idx - 2], ch):
                prob_list[i] *= 0.4
        return prob_list

    def train(self, n_epochs = 6):
        print("Training RNN-based generator ...")
        with tf.Session() as session:
            self._initialize_session(session)
            try:
                for epoch in range(n_epochs):
                    batch_no = 0
                    for keywords, contexts, sentences \
                            in batch_train_data(_BATCH_SIZE):
                        sys.stdout.write("[Seq2Seq Training] epoch = %d, " \
                                "line %d to %d ..." % 
                                (epoch, batch_no * _BATCH_SIZE,
                                (batch_no + 1) * _BATCH_SIZE))
                        sys.stdout.flush()
                        self._train_a_batch(session, epoch,
                                keywords, contexts, sentences)
                        batch_no += 1
                        if 0 == batch_no % 32:
                            self.saver.save(session, _model_path)
                    self.saver.save(session, _model_path)
                print("Training is done.")
            except KeyboardInterrupt:
                print("Training is interrupted.")

    def _train_a_batch(self, session, epoch, keywords, contexts, sentences):
        keyword_data, keyword_length = self._fill_np_matrix(keywords, True)
        context_data, context_length = self._fill_np_matrix(contexts, False)
        decoder_inputs, decoder_input_length  = self._fill_np_matrix(
                [start_of_sentence() + sentence[:-1] \
                        for sentence in sentences], False)
        targets = self._fill_targets(sentences)
        feed_dict = {
                self.keyword : keyword_data,
                self.keyword_length : keyword_length,
                self.context : context_data,
                self.context_length : context_length,
                self.decoder_inputs : decoder_inputs,
                self.decoder_input_length : decoder_input_length,
                self.targets : targets
                }
        loss, learning_rate, _ = session.run(
                [self.loss, self.learning_rate, self.opt_step],
                feed_dict = feed_dict)
        print(" loss =  %f, learning_rate = %f" % (loss, learning_rate))

    def _fill_np_matrix(self, texts, keyword = True):
        
        max_time = max(map(len, texts))
        # Edited
        if keyword: # one keyword each, 64 keywords in total for 16 poems
            matrix = np.zeros([_BATCH_SIZE, max_time, CHAR_VEC_DIM], 
                    dtype = np.float32)
            for i in range(_BATCH_SIZE):
                for j in range(max_time):
                    matrix[i, j, :] = self.planner.get_vect(end_of_sentence())
            for i, text in enumerate(texts):
                matrix[i, : len(text)] = self.planner.get_vects(text)
            seq_length = [len(texts[i]) if i < len(texts) else 0 \
                    for i in range(_BATCH_SIZE)]
        else:
            matrix = np.zeros([_BATCH_SIZE, max_time, CHARPIN_VEC_DIM], 
                    dtype = np.float32)
            for i in range(_BATCH_SIZE):
                for j in range(max_time):
                    matrix[i, j, :] = self.char2vec.get_vect(end_of_sentence())
            for i, text in enumerate(texts):
                matrix[i, : len(text)] = self.char2vec.get_vects(text)
            seq_length = [len(texts[i]) if i < len(texts) else 0 \
                    for i in range(_BATCH_SIZE)]
        return matrix, seq_length

    def _fill_targets(self, sentences):
        targets = []
        for sentence in sentences:
            targets.extend(map(self.char_dict.char2int, sentence))
        return targets
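In generate() above, the next character is drawn by a manual scan over the cumulative sums of prob_list. A self-contained NumPy equivalent of that weighted draw (illustrative):

import numpy as np

def sample_index(prob_list):
    # Weighted random draw, equivalent to the cumsum-and-scan loop above.
    prob_sums = np.cumsum(prob_list)
    rand_val = prob_sums[-1] * np.random.random()
    return int(np.searchsorted(prob_sums, rand_val, side='right'))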
Example #17
def init():
    global planner, generator2
    # load the pre-trained models once at startup
    planner = Planner()
    generator2 = Generator()
Example #18
def __init__(self):
    self.planner = Planner()
    self.predictor = Seq2SeqPredictor()
    self.Judge = MatchUtil()
Example #19
class Seq2SeqPredictor:
    def __init__(self, gpu_number=0):
        # Load model config
        config = load_config(FLAGS)

        config_proto = tf.ConfigProto(
            # allow_soft_placement=FLAGS.allow_soft_placement,
            # log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(
                allow_growth=True)  #, visible_device_list=str(gpu_number))
        )
        self.graphpre = tf.Graph()
        self.sess = tf.Session(graph=self.graphpre, config=config_proto)

        with self.sess.as_default():
            with self.graphpre.as_default():

                # Build the model
                self.model = Seq2SeqModel(config, 'predict')

                # Create saver
                # Using var_list = None returns the list of all saveable variables
                saver = tf.train.Saver(var_list=None)

                # Reload existing checkpoint
                load_model(self.sess, self.model, saver)
                self.planner = Planner()

                print("poetry is ok!")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.sess.close()

    def Normalize(self, words):  # added by zjg
        text = ' '.join(words)
        keywords = self.planner.plan(text)
        return keywords

    def predict(self, keywords):
        sentences = []
        keywords = self.Normalize(keywords)  # add by zjg
        for keyword in keywords:
            source, source_len = prepare_batch_predict_data(
                keyword,
                previous=sentences,
                prev=FLAGS.prev_data,
                rev=FLAGS.rev_data,
                align=FLAGS.align_data)
            with self.sess.as_default():
                with self.graphpre.as_default():
                    predicted_batch = self.model.predict(
                        self.sess,
                        encoder_inputs=source,
                        encoder_inputs_length=source_len)

            predicted_line = predicted_batch[0]  # predicted is a batch of one line
            predicted_line_clean = predicted_line[:-1]  # remove the end token
            predicted_ints = list(map(
                lambda x: x[0], predicted_line_clean
            ))  # Flatten from [time_step, 1] to [time_step]
            predicted_sentence = ints_to_sentence(predicted_ints)

            if FLAGS.rev_data:
                predicted_sentence = predicted_sentence[::-1]

            sentences.append(predicted_sentence)
        return sentences
Example #20
#! /usr/bin/env python
# -*- coding:utf-8 -*-

from data_utils import *
from plan import Planner
from generate import Generator
import sys

reload(sys)
sys.setdefaultencoding('utf8')

if __name__ == '__main__':
    planner = Planner()
    generator = Generator()
    while True:
        line = raw_input('Input Text:\t').decode('utf-8').strip()
        if line.lower() == 'quit' or line.lower() == 'exit':
            break
        elif len(line) > 0:
            keywords = planner.plan(line)
            print "Keywords:\t",
            for word in keywords:
                print word,
            print '\n'
            print "Poem Generated:\n"
            sentences = generator.generate(keywords)
            print '\t' + sentences[0] + u',' + sentences[1] + u'。'
            print '\t' + sentences[2] + u',' + sentences[3] + u'。'
            print
Example #21
File: main.py Project: Abingcbc/MuTao
#! /usr/bin/env python3
# -*- coding: utf-8 -*-

from generate import Generator
from plan import Planner


if __name__ == '__main__':
    planner = Planner()
    generator = Generator()
    while True:
        hints = input("Type in hints >> ")
        keywords = planner.plan(hints)
        print("Keywords: " + ' '.join(keywords))
        poem = generator.generate(keywords)
        print("Poem generated:")
        for sentence in poem:
            print(sentence)
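Across all of these examples, Planner exposes a single plan(text) method that maps free text to a short list of keywords (usually four), which a generator or predictor then expands into poem lines. A minimal stub with that interface, for exercising the snippets without the real models (purely illustrative):

class PlannerStub:
    """Stand-in exposing the plan() interface the examples above rely on."""

    def plan(self, text):
        # Real planners segment text and rank words (e.g. by TextRank);
        # this stub just pads or truncates whitespace-separated tokens to four.
        words = text.split() or [u'日']
        return (words * 4)[:4]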