Code Example #1
def main(_):
    converter = TextConverter(filename=FLAGS.converter_path)
    if os.path.isdir(FLAGS.checkpoint_path):
        FLAGS.checkpoint_path =\
            tf.train.latest_checkpoint(FLAGS.checkpoint_path)

    model = CharRNN(converter.vocab_size,
                    sampling=True,
                    lstm_size=FLAGS.lstm_size,
                    num_layers=FLAGS.num_layers,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size)

    model.load(FLAGS.checkpoint_path)

    sys.stdout.write("> ")
    sys.stdout.flush()
    start_string = sys.stdin.readline()
    while start_string:
        start = converter.text_to_arr(start_string)
        arr = model.sample(FLAGS.max_length, start, converter.vocab_size)
        print(converter.arr_to_text(arr))

        sys.stdout.write("> ")
        sys.stdout.flush()
        start_string = sys.stdin.readline()  # update the loop variable so the loop can terminate
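
All of these samples read their settings from TensorFlow command-line flags, but none of them shows the flag definitions. Below is a minimal sketch of what they might look like, assuming TF1-style tf.app.flags (under TF2 the equivalent lives in tf.compat.v1 or absl.flags); every default value here is hypothetical:

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('converter_path', '', 'path to the pickled TextConverter')
tf.app.flags.DEFINE_string('checkpoint_path', '', 'checkpoint file or directory')
tf.app.flags.DEFINE_string('start_string', '', 'seed text for sampling')
tf.app.flags.DEFINE_integer('max_length', 100, 'number of tokens to generate')
tf.app.flags.DEFINE_integer('lstm_size', 128, 'LSTM hidden size')
tf.app.flags.DEFINE_integer('num_layers', 2, 'number of LSTM layers')
tf.app.flags.DEFINE_boolean('use_embedding', False, 'use an embedding layer')
tf.app.flags.DEFINE_integer('embedding_size', 128, 'embedding dimension')

if __name__ == '__main__':
    tf.app.run()  # parses the flags, then calls main(_)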
Code Example #2
def main(_):
    converter = TextConverter(filename=FLAGS.converter_path)
    if os.path.isdir(FLAGS.checkpoint_path):
        FLAGS.checkpoint_path = tf.train.latest_checkpoint(
            FLAGS.checkpoint_path)

    model = CharRNN(converter.vocab_size,
                    None,
                    sampling=True,
                    lstm_size=FLAGS.lstm_size,
                    num_layers=FLAGS.num_layers,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size)

    model.load(FLAGS.checkpoint_path)

    # start = converter.text_to_arr(FLAGS.seed_for_generating)
    seeds = [
        'var a = fun', 'function a(', 'this.', 'document.', 'window.',
        'var a = document.g', 'var a;', 'jQuery'
    ]
    for seed in seeds:
        start = converter.text_to_arr(seed)
        for i in range(0, FLAGS.num_to_generate):
            print('Generating: ' + seed + ' -> ' + str(i))
            file_name = str(uuid.uuid1())
            file_path = os.path.join('../../BrowserFuzzingData/generated',
                                     FLAGS.file_type,
                                     file_name + '.' + FLAGS.file_type)
            arr = model.sample(FLAGS.max_length_of_generated, start,
                               converter.vocab_size, converter.word_to_int)
            with open(file_path, "wb") as f:
                f.write(converter.arr_to_text(arr).encode('utf-8'))
Code Example #3
File: test.py Project: zoulala/Seq2seq_couplet
def main(_):

    model_path = os.path.join('models', Config.file_name)

    converter = TextConverter(vocab_dir='data/vocabs',
                              max_vocab=Config.vocab_size,
                              seq_length=Config.seq_length)
    print('vocab lens:', converter.vocab_size)

    # Load the previously saved model
    model = Model(Config)
    checkpoint_path = tf.train.latest_checkpoint(model_path)
    if checkpoint_path:
        model.load(checkpoint_path)

    while True:

        english_speek = input("上联:")
        english_speek = ' '.join(english_speek)
        english_speek = english_speek.split()
        en_arr, arr_len = converter.text_en_to_arr(english_speek)

        test_g = [np.array([
            en_arr,
        ]), np.array([
            arr_len,
        ])]
        output_ids = model.test(test_g, model_path, converter)
        strs = converter.arr_to_text(output_ids)
        print('下联:', strs)
Code Example #4
def main(_):

    model_path = os.path.join('models', Config.file_name)

    et = TextConverter(text=None, save_dir='models/en_vocab.pkl',
                       max_vocab=Config.en_vocab_size, seq_length=Config.seq_length)
    # seq_length + 1 because the decoder sequence is split into input=[:-1] and label=[1:]
    zt = TextConverter(text=None, save_dir='models/zh_vocab.pkl',
                       max_vocab=Config.zh_vocab_size, seq_length=Config.seq_length + 1)
    print('english vocab lens:', et.vocab_size)
    print('chinese vocab lens:', zt.vocab_size)


    # Load the previously saved model
    model = Model(Config)
    checkpoint_path = tf.train.latest_checkpoint(model_path)
    if checkpoint_path:
        model.load(checkpoint_path)

    while True:
        # english_speek = 'what can i help you ?'
        # print('english:', english_speek)
        english_speek = input("english:")

        english_speek = english_speek.split()
        en_arr, arr_len = et.text_to_arr(english_speek)

        test_g = [np.array([en_arr]), np.array([arr_len])]
        output_ids = model.test(test_g, model_path, zt)
        strs = zt.arr_to_text(output_ids)
        print('chinese:', strs)
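
The seq_length + 1 comment above refers to the standard teacher-forcing split. A small sketch of what it means, with a made-up id sequence (the values are illustrative only):

import numpy as np

seq = np.array([101, 7, 42, 9, 55])  # hypothetical ids, length = seq_length + 1
decoder_input = seq[:-1]             # [101, 7, 42, 9]  fed into the decoder
decoder_label = seq[1:]              # [7, 42, 9, 55]   target at each time step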
Code Example #5
File: gen.py Project: gajanlee/tiny-flow
def main(_):
    converter = TextConverter(filename=FLAGS.converter_path)

    model = charRNN(converter.vocab_size, train=False)
    model.load(tf.train.latest_checkpoint(FLAGS.checkpoint_path))

    start = converter.text_to_arr(FLAGS.start_string)
    arr = model.generate(FLAGS.max_length, start, converter.vocab_size)
    print(converter.arr_to_text(arr))
Code Example #6
def main(_):
    converter = TextConverter(filename=FLAGS.converter_path)
    if os.path.isdir(FLAGS.checkpoint_path):
        FLAGS.checkpoint_path = \
            tf.train.latest_checkpoint(FLAGS.checkpoint_path)

    model = CharRNN(converter.vocab_size,
                    sampling=True,
                    lstm_size=FLAGS.lstm_size,
                    num_layers=FLAGS.num_layers,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size)

    model.load(FLAGS.checkpoint_path)

    start = converter.text_to_arr(FLAGS.start_string)
    arr = model.sample(FLAGS.max_length, start, converter.vocab_size)
    print(converter.arr_to_text(arr))
Code Example #7
def test_vocab_size(self):
    testConverter = TextConverter(text=[
        "We", "are", "accounted", "poor", "citizens,", "the", "patricians",
        "goodare", "accounted", "poor", "citizens,", "the", "patricians",
        "good"
    ], max_vocab=10)
    print(testConverter.vocab_size)
    print(testConverter.int_to_word(4))
    print(testConverter.text_to_arr(['the']))
    print(testConverter.arr_to_text([3, 4]))
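
TextConverter itself is not shown in any of these examples. A minimal sketch of the interface the test above exercises, assuming a frequency-capped vocabulary whose last id is reserved for out-of-vocabulary tokens; this is a guess at the behavior, not the projects' actual code:

from collections import Counter

import numpy as np

class TextConverter(object):
    def __init__(self, text=None, max_vocab=5000):
        # keep the max_vocab most frequent tokens, most frequent first
        self.vocab = [w for w, _ in Counter(text).most_common(max_vocab)]
        self.word_to_int_table = {w: i for i, w in enumerate(self.vocab)}

    @property
    def vocab_size(self):
        return len(self.vocab) + 1  # +1 for the out-of-vocabulary id

    def int_to_word(self, index):
        return self.vocab[index] if index < len(self.vocab) else '<unk>'

    def word_to_int(self, word):
        return self.word_to_int_table.get(word, len(self.vocab))

    def text_to_arr(self, text):
        return np.array([self.word_to_int(w) for w in text])

    def arr_to_text(self, arr):
        return ''.join(self.int_to_word(i) for i in arr)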
Code Example #8
def sample():

    with tf.Session() as sess:
        model_path = os.path.join(FLAGS.train_dir, FLAGS.model_name)
        converter = TextConverter(None, FLAGS.max_vocab_size,
                                  os.path.join(model_path, 'converter.pkl'))
        model = create_model(sess, converter.vocab_size, True, model_path)

        sys.stdout.write("> ")
        sys.stdout.flush()
        start_str = sys.stdin.readline()
        while start_str:
            start = converter.text_to_arr(start_str)

            samples = [c for c in start]
            initial_state = sess.run(model.initial_state)
            x = np.zeros((1, 1))
            for c in start:
                x[0, 0] = c
                feed = {model.inputs: x, model.initial_state: initial_state}
                preds, final_state = sess.run(
                    [model.proba_prediction, model.final_state],
                    feed_dict=feed)
                initial_state = final_state

            c = pick_top_n(preds, converter.vocab_size)
            # resample while the last vocab id (presumably an <unk>/padding token) is drawn
            while c == converter.vocab_size - 1:
                c = pick_top_n(preds, converter.vocab_size)
            samples.append(c)

            for i in range(FLAGS.sample_length):
                x[0, 0] = c
                feed = {model.inputs: x, model.initial_state: initial_state}
                preds, final_state = sess.run(
                    [model.proba_prediction, model.final_state],
                    feed_dict=feed)
                initial_state = final_state
                c = pick_top_n(preds, converter.vocab_size)
                while c == converter.vocab_size - 1:  # same resampling guard as above
                    c = pick_top_n(preds, converter.vocab_size)
                samples.append(c)

            print(converter.arr_to_text(np.array(samples)))

            sys.stdout.write("> ")
            sys.stdout.flush()
            start_str = sys.stdin.readline()
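
pick_top_n is not defined in the snippet above. A common implementation from char-RNN samplers (an assumption, not necessarily this project's version) draws the next id from the top_n most probable ones:

import numpy as np

def pick_top_n(preds, vocab_size, top_n=5):
    p = np.squeeze(preds).copy()
    p[np.argsort(p)[:-top_n]] = 0  # zero out all but the top_n probabilities
    p = p / np.sum(p)              # renormalize
    return np.random.choice(vocab_size, 1, p=p)[0]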
Code Example #9
File: sample.py Project: sophistcxf/ThirdLibTest
def main(_):
    converter = TextConverter(filename=FLAGS.converter_path)
    if os.path.isdir(FLAGS.checkpoint_path):
        FLAGS.checkpoint_path =\
            tf.train.latest_checkpoint(FLAGS.checkpoint_path)

    model = CharRNN(converter.vocab_size, sampling=True,
                    lstm_size=FLAGS.lstm_size, num_layers=FLAGS.num_layers,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size)

    model.load(FLAGS.checkpoint_path)

    start = converter.text_to_arr(FLAGS.start_string)
    arr = model.sample(FLAGS.max_length, start, converter.vocab_size)
    print(converter.arr_to_text(arr))
Code Example #10
def main(_):
    converter = TextConverter(filename=FLAGS.converter_path)  # create the text converter
    if os.path.isdir(FLAGS.checkpoint_path):
        FLAGS.checkpoint_path = tf.train.latest_checkpoint(
            FLAGS.checkpoint_path)  # resolve the latest checkpoint

    model = CharRNN(converter.vocab_size,
                    sampling=True,
                    lstm_size=FLAGS.lstm_size,
                    num_layers=FLAGS.num_layers,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size)

    model.load(FLAGS.checkpoint_path)  # load the model

    start = converter.text_to_arr(FLAGS.start_string)  # convert the input text to ids
    arr = model.sample(FLAGS.max_length, start,
                       converter.vocab_size)  # the generated id sequence
    print(converter.arr_to_text(arr))
Code Example #11
def generate():
    tf.compat.v1.disable_eager_execution()
    converter = TextConverter(filename=FLAGS.converter_path)
    if os.path.isdir(FLAGS.checkpoint_path):
        FLAGS.checkpoint_path =\
            tf.train.latest_checkpoint(FLAGS.checkpoint_path)

    model = CharRNN(converter.vocab_size,
                    sampling=True,
                    lstm_size=FLAGS.lstm_size,
                    num_layers=FLAGS.num_layers,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size)

    model.load(FLAGS.checkpoint_path)

    start = converter.text_to_arr(FLAGS.start_string)
    arr = model.sample(FLAGS.max_length, start, converter.vocab_size)

    return converter.arr_to_text(arr)
Code Example #12
File: sample.py Project: yeah529/dianpin-smallapp
class Dianpin(Singleton):
    def __init__(self):
        self.text = ''
        self.tfmodel = None
        self.converter = None

    def model_built(self):
        self.converter = TextConverter(filename=FLAGS.converter_path)
        if os.path.isdir(FLAGS.checkpoint_path):
            FLAGS.checkpoint_path =\
                tf.train.latest_checkpoint(FLAGS.checkpoint_path)
        self.tfmodel = CharRNN(self.converter.vocab_size, sampling=True,
                               lstm_size=FLAGS.lstm_size, num_layers=FLAGS.num_layers,
                               use_embedding=FLAGS.use_embedding,
                               embedding_size=FLAGS.embedding_size)
        self.tfmodel.load(FLAGS.checkpoint_path)
        
    def final_predict(self):
        start = self.converter.text_to_arr(FLAGS.start_string)
        arr = self.tfmodel.sample(FLAGS.max_length, start, self.converter.vocab_size)
        return self.converter.arr_to_text(arr)
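
The Singleton base class that Dianpin inherits from is not shown. A common minimal implementation (an assumption, not this project's confirmed code) caches one instance per class via __new__:

class Singleton(object):
    _instance = None

    def __new__(cls, *args, **kwargs):
        # create the instance once, then keep returning the same object
        if cls._instance is None:
            cls._instance = super(Singleton, cls).__new__(cls)
        return cls._instance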
Code Example #13
def main(_):
    converter = TextConverter(filename=FLAGS.converter_path)
    if os.path.isdir(FLAGS.checkpoint_path):
        FLAGS.checkpoint_path =\
            tf.train.latest_checkpoint(FLAGS.checkpoint_path)

    model = CharRNN(converter.vocab_size, sampling=True,
                    lstm_size=FLAGS.lstm_size, num_layers=FLAGS.num_layers,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size)

    model.load(FLAGS.checkpoint_path)

    start = converter.text_to_arr(FLAGS.start_string)
    arr = model.predict(FLAGS.max_length, start, converter.vocab_size, 10)
    for c, p in arr:
        prediction = converter.arr_to_text(c)
        prediction = remove_return(prediction)

        # If Chinese characters are generated, change {1:^14} to {1:{4}^14} to fix the alignment.
        # The 14 in {1:^14} depends on how many characters are generated; character count + 4 usually works.

        print("{0} -> {1:^14} {2} {3}".format(
            FLAGS.start_string, prediction, "probability:", p, chr(12288)))
Code Example #14
File: sample.py Project: 336655asd/AI-ARTIST
def poem_genetate(poem_start=u'君'):
    converter = TextConverter(filename=FLAGS.converter_path)
    if os.path.isdir(FLAGS.checkpoint_path):
        FLAGS.checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
        print(FLAGS.checkpoint_path)
    """
    model = CharRNN(converter.vocab_size, sampling=True,
                    lstm_size=FLAGS.lstm_size, num_layers=FLAGS.num_layers,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size)
                    """
    model = CharRNN(converter.vocab_size, sampling=True,
                    lstm_size=lstm_size, num_layers=num_layers,
                    use_embedding=use_embedding,embedding_size=FLAGS.embedding_size)

    model.load(FLAGS.checkpoint_path)

    start1 = converter.text_to_arr(poem_start)
    arr = model.sample(max_length, start1, converter.vocab_size)
    #pl = model.poemline(max_length, start, converter.vocab_size)
    #sp=model.sample_hide_poetry( start, converter.vocab_size)
    poem=converter.arr_to_text(arr)
    #print (converter.arr_to_text(sp))
    print('---------')
    print(poem)
    print('---------')
    #print(converter.arr_to_text(pl))
    print('---------')
    # 0 = ',', 1 = '。', 2 = '\n'; each line is 12 characters, and ',' / '。' / '\n' may each appear at most once
    
    lines = poem.split('\n')
    r_poem = []
    for i in range(len(lines)):
        if len(lines[i]) == 12:
            count = 0
            # keep only well-formed lines: ',' at index 5 and '。' at index 11,
            # with exactly one of each in the line
            if lines[i][5] == u',' and lines[i][11] == u'。':
                for j in range(len(lines[i])):
                    if lines[i][j] == u',' or lines[i][j] == u'。':
                        count += 1
                if count == 2:
                    r_poem.append(lines[i])
        if len(r_poem) == 2:
            break

    """
    lines=poem.split('\n')
    r_poem=[]
    for i in range(len(lines)):
        if len(lines[i])==12:
            count=0
            if lines[i][5]==0 and lines[i][11]==1:
                for j in range(len(lines[i])):
                    if lines[i][j]==0 or lines[i][j]==1:
                        count+=1
                if count==2:
                    r_poem.append(lines[i])
        if len(r_poem)==2:
            break
            """
    with codecs.open("app/poem.txt", "w", 'utf-8') as f:
        words = "".join(r_poem)
        f.write(words)