Example #1
    def __init__(self, args):
        self.args = args

        logging.basicConfig(stream=sys.stdout,
                            format='%(asctime)s %(levelname)s:%(message)s',
                            level=logging.INFO,
                            datefmt='%I:%M:%S')

        #with open(os.path.join(self.args.model_dir, 'result.json'), 'r') as f:
        # Use an absolute path instead of the relative one above
        with open(r'G:\Postgraduate\myClassesOfGraduateStudentOne\张华平大数据人工智能课程\大作业\字符级LSTM文本自动生成模型\TangPoemGenerator\poet\output_poem\result.json', 'r') as f:
            result = json.load(f)

        params = result['params']
        best_model = result['best_model']
        best_valid_ppl = result['best_valid_ppl']
        if 'encoding' in result:
            self.args.encoding = result['encoding']
        else:
            self.args.encoding = 'utf-8'

        base_path = args.data_dir
        """
        源代码中都是使用os.path.join 形成的路径在windows下面,类似于
        “FileNotFoundError: [Errno 2] No such file or directory: './data/poem\\rhyme_words.txt'”
        这样的错误,
        所以本次试验使用 绝对路径进行替换
        """

        #w2v_file = os.path.join(base_path, "vectors_poem.bin")
        w2v_file=r'G:\Postgraduate\myClassesOfGraduateStudentOne\张华平大数据人工智能课程\大作业\字符级LSTM文本自动生成模型\TangPoemGenerator\poet\data\poem\vectors_poem.bin'

        # Load the word embeddings
        self.w2v = Word2Vec(w2v_file)

        # Read the rhyme words
        RhymeWords.read_rhyme_words(r'G:\Postgraduate\myClassesOfGraduateStudentOne\张华平大数据人工智能课程\大作业\字符级LSTM文本自动生成模型\TangPoemGenerator\poet\data\poem\rhyme_words.txt')

        if args.seed >= 0:
            np.random.seed(args.seed)

        logging.info('best_model: %s\n', best_model)
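        # Override the checkpoint path from result.json with a hard-coded absolute path.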
        best_model=r'G:\Postgraduate\myClassesOfGraduateStudentOne\张华平大数据人工智能课程\大作业\字符级LSTM文本自动生成模型\TangPoemGenerator\poet\output_poem\best_model\model-16312'

        self.sess = tf.Session()
        w2v_vocab_size = len(self.w2v.model.vocab)
        with tf.name_scope('evaluation'):
            self.model = CharRNNLM(is_training=False,
                                   w2v_model=self.w2v.model,
                                   vocab_size=w2v_vocab_size,
                                   infer=True,
                                   **params)
            saver = tf.train.Saver(name='model_saver')
            saver.restore(self.sess, best_model)
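
The docstring in Example #1 blames os.path.join for the Windows FileNotFoundError, but that error usually means the relative path did not resolve from the current working directory (Windows file APIs accept mixed separators in most cases). Below is a minimal sketch of a portable alternative to the hard-coded absolute paths, assuming the data lives in a data/poem folder next to the script; data_path is a hypothetical helper, not part of the source:

import os

# Resolve paths relative to this file so they work from any working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

def data_path(*parts):
    # os.path.join picks the right separator for the current OS.
    return os.path.join(BASE_DIR, 'data', 'poem', *parts)

w2v_file = data_path('vectors_poem.bin')
rhyme_file = data_path('rhyme_words.txt')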
Example #2
    def __init__(self, args):
        self.args = args

        logging.basicConfig(stream=sys.stdout,
                            format='%(asctime)s %(levelname)s:%(message)s',
                            level=logging.INFO,
                            datefmt='%I:%M:%S')

        with open(os.path.join(self.args.model_dir, 'result.json'), 'r') as f:
            result = json.load(f)

        params = result['params']
        best_model = result['best_model']
        best_valid_ppl = result['best_valid_ppl']
        if 'encoding' in result:
            self.args.encoding = result['encoding']
        else:
            self.args.encoding = 'utf-8'

        base_path = args.data_dir
        w2v_file = os.path.join(base_path, "vectors_poem.bin")
        self.w2v = Word2Vec(w2v_file)

        RhymeWords.read_rhyme_words(os.path.join(base_path, 'rhyme_words.txt'))

        if args.seed >= 0:
            np.random.seed(args.seed)

        logging.info('best_model: %s\n', best_model)

        #        self.sess = tf.Session()

        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.2
        self.sess = tf.Session(config=config)

        #        self.sess = tf.Session(config=tf.ConfigProto(device_count={'gpu':0}))
        w2v_vocab_size = len(self.w2v.model.vocab)
        with tf.name_scope('evaluation'):
            self.model = CharRNNLM(is_training=False,
                                   w2v_model=self.w2v.model,
                                   vocab_size=w2v_vocab_size,
                                   infer=True,
                                   **params)
            saver = tf.train.Saver(name='model_saver')
            saver.restore(self.sess, best_model)
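
Example #2 differs from Example #1 mainly in the session setup: the ConfigProto caps this process at roughly 20% of GPU memory so several jobs can share one card. A minimal TF1-style sketch of the two common options; the allow_growth variant is an alternative, not what the source uses:

import tensorflow as tf

config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # grow the allocation on demand
# config.gpu_options.per_process_gpu_memory_fraction = 0.2  # or: hard cap at ~20%
sess = tf.Session(config=config)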
Example #3
def main(args=''):
    args = config_poem_train(args)
    # Specify locations to store output_chat, the best output_chat, and the TensorBoard log.
    args.save_model = os.path.join(args.output_dir, 'save_model/output_chat')
    args.save_best_model = os.path.join(args.output_dir,
                                        'best_model/output_chat')
    # args.tb_log_dir = os.path.join(args.output_dir, 'tensorboard_log/')
    timestamp = str(int(time.time()))
    args.tb_log_dir = os.path.abspath(
        os.path.join(args.output_dir, "tensorboard_log", timestamp))
    print("Writing to {}\n".format(args.tb_log_dir))

    # Create necessary directories.
    if len(args.init_dir) != 0:
        args.output_dir = args.init_dir
    else:
        if os.path.exists(args.output_dir):
            shutil.rmtree(args.output_dir)
        for paths in [args.save_model, args.save_best_model, args.tb_log_dir]:
            os.makedirs(os.path.dirname(paths))

    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s:%(message)s',
                        level=logging.INFO,
                        datefmt='%I:%M:%S')

    print('=' * 60)
    print('All final and intermediate outputs will be stored in %s/' %
          args.output_dir)
    print('=' * 60 + '\n')

    logging.info('args are:\n%s', args)

    if len(args.init_dir) != 0:
        with open(os.path.join(args.init_dir, 'result.json'), 'r') as f:
            result = json.load(f)
        params = result['params']
        args.init_model = result['latest_model']
        best_model = result['best_model']
        best_valid_ppl = result['best_valid_ppl']
        if 'encoding' in result:
            args.encoding = result['encoding']
        else:
            args.encoding = 'utf-8'

    else:
        params = {
            'batch_size': args.batch_size,
            'num_unrollings': args.num_unrollings,
            'hidden_size': args.hidden_size,
            'max_grad_norm': args.max_grad_norm,
            'embedding_size': args.embedding_size,
            'num_layers': args.num_layers,
            'learning_rate': args.learning_rate,
            'cell_type': args.cell_type,
            'dropout': args.dropout,
            'input_dropout': args.input_dropout
        }
        best_model = ''
    logging.info('Parameters are:\n%s\n',
                 json.dumps(params, sort_keys=True, indent=4))

    # Create batch generators.
    batch_size = params['batch_size']
    num_unrollings = params['num_unrollings']

    base_path = args.data_path
    w2v_file = os.path.join(base_path, "vectors_poem.bin")
    w2v = Word2Vec(w2v_file)

    train_data_loader = DataLoader(base_path, batch_size, num_unrollings,
                                   w2v.model, 'train')
    test1_data_loader = DataLoader(base_path, batch_size, num_unrollings,
                                   w2v.model, 'test')
    valid_data_loader = DataLoader(base_path, batch_size, num_unrollings,
                                   w2v.model, 'valid')

    # Create graphs
    logging.info('Creating graph')
    graph = tf.Graph()
    with graph.as_default():
        w2v_vocab_size = len(w2v.model.vocab)
        with tf.name_scope('training'):
            train_model = CharRNNLM(is_training=True,
                                    w2v_model=w2v.model,
                                    vocab_size=w2v_vocab_size,
                                    infer=False,
                                    **params)
            tf.get_variable_scope().reuse_variables()

        with tf.name_scope('validation'):
            valid_model = CharRNNLM(is_training=False,
                                    w2v_model=w2v.model,
                                    vocab_size=w2v_vocab_size,
                                    infer=False,
                                    **params)

        with tf.name_scope('evaluation'):
            test_model = CharRNNLM(is_training=False,
                                   w2v_model=w2v.model,
                                   vocab_size=w2v_vocab_size,
                                   infer=False,
                                   **params)
            saver = tf.train.Saver(name='model_saver')
            best_model_saver = tf.train.Saver(name='best_model_saver')

    logging.info('Start training\n')

    result = {}
    result['params'] = params

    try:
        with tf.Session(graph=graph) as session:
            # Version 8 changed the API of the summary writer to take
            # a graph instead of a graph_def.
            if TF_VERSION >= 8:
                graph_info = session.graph
            else:
                graph_info = session.graph_def

            train_summary_dir = os.path.join(args.tb_log_dir, "summaries",
                                             "train")
            train_writer = tf.summary.FileWriter(train_summary_dir, graph_info)
            valid_summary_dir = os.path.join(args.tb_log_dir, "summaries",
                                             "valid")
            valid_writer = tf.summary.FileWriter(valid_summary_dir, graph_info)

            # Load a saved output_chat or start from random initialization.
            if len(args.init_model) != 0:
                saver.restore(session, args.init_model)
            else:
                tf.global_variables_initializer().run()

            learning_rate = args.learning_rate
            for epoch in range(args.num_epochs):
                logging.info('=' * 19 + ' Epoch %d ' + '=' * 19 + '\n', epoch)
                logging.info('Training on training set')
                # training step
                ppl, train_summary_str, global_step = train_model.run_epoch(
                    session,
                    train_data_loader,
                    is_training=True,
                    learning_rate=learning_rate,
                    verbose=args.verbose,
                    freq=args.progress_freq)
                # record the summary
                train_writer.add_summary(train_summary_str, global_step)
                train_writer.flush()
                # save output_chat
                saved_path = saver.save(session,
                                        args.save_model,
                                        global_step=train_model.global_step)

                logging.info('Latest output_chat saved in %s\n', saved_path)
                logging.info('Evaluate on validation set')

                valid_ppl, valid_summary_str, _ = valid_model.run_epoch(
                    session,
                    valid_data_loader,
                    is_training=False,
                    learning_rate=learning_rate,
                    verbose=args.verbose,
                    freq=args.progress_freq)

                # save and update best output_chat
                if (len(best_model) == 0) or (valid_ppl < best_valid_ppl):
                    best_model = best_model_saver.save(
                        session,
                        args.save_best_model,
                        global_step=train_model.global_step)
                    best_valid_ppl = valid_ppl
                else:
                    learning_rate /= 2.0
                    logging.info('Decay the learning rate: ' +
                                 str(learning_rate))

                valid_writer.add_summary(valid_summary_str, global_step)
                valid_writer.flush()

                logging.info('Best output_chat is saved in %s', best_model)
                logging.info('Best validation ppl is %f\n', best_valid_ppl)

                result['latest_model'] = saved_path
                result['best_model'] = best_model
                # Convert to float because numpy.float is not json serializable.
                result['best_valid_ppl'] = float(best_valid_ppl)

                result_path = os.path.join(args.output_dir, 'result.json')
                if os.path.exists(result_path):
                    os.remove(result_path)
                with open(result_path, 'w') as f:
                    json.dump(result, f, indent=2, sort_keys=True)

            logging.info('Latest output_chat is saved in %s', saved_path)
            logging.info('Best output_chat is saved in %s', best_model)
            logging.info('Best validation ppl is %f\n', best_valid_ppl)

            logging.info('Evaluate the best output_chat on test set')
            saver.restore(session, best_model)
            test_ppl, _, _ = test_model.run_epoch(session,
                                                  test1_data_loader,
                                                  is_training=False,
                                                  learning_rate=learning_rate,
                                                  verbose=args.verbose,
                                                  freq=args.progress_freq)
            result['test_ppl'] = float(test_ppl)
    except Exception as e:
        print('Error: {}'.format(e))
    finally:
        result_path = os.path.join(args.output_dir, 'result.json')
        if os.path.exists(result_path):
            os.remove(result_path)
        with open(result_path, 'w', encoding='utf-8', errors='ignore') as f:
            json.dump(result, f, indent=2, sort_keys=True)
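
The schedule at the heart of Example #3 is easy to miss among the session plumbing: a checkpoint is kept whenever validation perplexity improves, and the learning rate is halved whenever it does not. Here is a standalone sketch of that decay-on-plateau rule, with made-up perplexity values:

best_valid_ppl = float('inf')
learning_rate = 0.001

for epoch, valid_ppl in enumerate([120.0, 95.0, 97.0, 90.0]):
    if valid_ppl < best_valid_ppl:
        best_valid_ppl = valid_ppl  # new best: save the checkpoint here
    else:
        learning_rate /= 2.0        # plateau: decay the learning rate
    print(epoch, valid_ppl, best_valid_ppl, learning_rate)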
Example #4
                if (w not in vocab):
                    w = '<unknown>'
                line.append(vocab[w])
            text_array.append(line)
            # else:
            #     print(w, 'not exist')

        return text_array

if __name__ == '__main__':
    base_path = './data/poem'
    # poem = '风急云轻鹤背寒,洞天谁道却归难。千山万水瀛洲路,何处烟飞是醮坛。是的'
    # idx = poem.rfind('。')
    # poem_part = poem[:idx + 1]
    w2v_file = os.path.join(base_path, "vectors_poem.bin")
    w2v = Word2Vec(w2v_file)

    # vect = w2v_model['['][:10]
    # print(vect)
    #
    # vect = w2v_model['春'][:10]
    # print(vect)

    in_file = os.path.join(base_path, 'poems_edge.txt')
    # fr = open(in_file, "r",encoding='utf-8')
    # poems = fr.readlines()
    # fr.close()
    #
    #
    #
    # print("唐诗总数: %d"%len(poems))