# Example 1
    def __init__(self, args):
        """Load the trained poem model for inference.

        Restores hyper-parameters and the best checkpoint path from
        ``<model_dir>/result.json``, loads the word2vec embeddings and
        the rhyme table from ``<data_dir>``, and builds the inference
        graph inside a fresh TF session.

        Args:
            args: config namespace; must provide ``model_dir``,
                ``data_dir`` and ``seed``. ``args.encoding`` is set here.
        """
        self.args = args

        logging.basicConfig(stream=sys.stdout,
                            format='%(asctime)s %(levelname)s:%(message)s',
                            level=logging.INFO, datefmt='%I:%M:%S')

        # BUGFIX: the previous version hard-coded machine-specific absolute
        # Windows paths for result.json, the embeddings, the rhyme table and
        # the checkpoint (overriding the path recorded in result.json), which
        # broke the code on any other machine. All paths are rebuilt from the
        # configured model_dir/data_dir; os.path.join handles the platform
        # separator correctly.
        with open(os.path.join(self.args.model_dir, 'result.json'), 'r') as f:
            result = json.load(f)

        params = result['params']
        # Path of the best checkpoint saved during training.
        best_model = result['best_model']
        best_valid_ppl = result['best_valid_ppl']
        if 'encoding' in result:
            self.args.encoding = result['encoding']
        else:
            self.args.encoding = 'utf-8'

        base_path = args.data_dir
        # Pre-trained word2vec embeddings for the poem characters.
        w2v_file = os.path.join(base_path, "vectors_poem.bin")
        self.w2v = Word2Vec(w2v_file)

        # Rhyme-word table used for rhyme-constrained sampling.
        RhymeWords.read_rhyme_words(os.path.join(base_path, 'rhyme_words.txt'))

        if args.seed >= 0:
            np.random.seed(args.seed)

        logging.info('best_model: %s\n', best_model)

        self.sess = tf.Session()
        w2v_vocab_size = len(self.w2v.model.vocab)
        with tf.name_scope('evaluation'):
            self.model = CharRNNLM(is_training=False,
                                   w2v_model=self.w2v.model,
                                   vocab_size=w2v_vocab_size,
                                   infer=True,
                                   **params)
            saver = tf.train.Saver(name='model_saver')
            saver.restore(self.sess, best_model)
# Example 2
    def __init__(self, args):
        """Restore a trained poem model for sampling.

        Reads hyper-parameters and the best checkpoint from
        ``<model_dir>/result.json``, loads embeddings and the rhyme
        table from ``<data_dir>``, and builds the inference graph in a
        GPU-memory-capped session.
        """
        self.args = args

        logging.basicConfig(stream=sys.stdout,
                            format='%(asctime)s %(levelname)s:%(message)s',
                            level=logging.INFO,
                            datefmt='%I:%M:%S')

        result_path = os.path.join(self.args.model_dir, 'result.json')
        with open(result_path, 'r') as f:
            result = json.load(f)

        params = result['params']
        best_model = result['best_model']
        best_valid_ppl = result['best_valid_ppl']
        # Default to utf-8 when the training run recorded no encoding.
        self.args.encoding = result.get('encoding', 'utf-8')

        base_path = args.data_dir
        self.w2v = Word2Vec(os.path.join(base_path, "vectors_poem.bin"))

        RhymeWords.read_rhyme_words(os.path.join(base_path, 'rhyme_words.txt'))

        if args.seed >= 0:
            np.random.seed(args.seed)

        logging.info('best_model: %s\n', best_model)

        # Cap per-process GPU memory so several processes can share one card.
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.2
        self.sess = tf.Session(config=config)

        vocab_size = len(self.w2v.model.vocab)
        with tf.name_scope('evaluation'):
            self.model = CharRNNLM(is_training=False,
                                   w2v_model=self.w2v.model,
                                   vocab_size=vocab_size,
                                   infer=True,
                                   **params)
            saver = tf.train.Saver(name='model_saver')
            saver.restore(self.sess, best_model)
# Example 3
class WritePoem():
    """Tang-poem generator built on a trained character-level RNN LM.

    Loads hyper-parameters and the best checkpoint recorded in
    ``<model_dir>/result.json``, the word2vec character embeddings and
    the rhyme-word table, then exposes four generation modes:
    ``free_verse`` (自由诗), ``rhyme_verse`` (押韵诗), ``hide_words``
    (藏字诗, hidden words) and ``cangtou`` (藏头诗, acrostic).
    """
    def __init__(self, args):
        # Config namespace; must provide model_dir, data_dir, seed and
        # length. args.encoding is set below from result.json.
        self.args = args

        logging.basicConfig(stream=sys.stdout,
                            format='%(asctime)s %(levelname)s:%(message)s',
                            level=logging.INFO,
                            datefmt='%I:%M:%S')

        # result.json is written by the training script; it records the
        # hyper-parameters and the best checkpoint path.
        with open(os.path.join(self.args.model_dir, 'result.json'), 'r') as f:
            result = json.load(f)

        params = result['params']
        best_model = result['best_model']
        best_valid_ppl = result['best_valid_ppl']
        if 'encoding' in result:
            self.args.encoding = result['encoding']
        else:
            self.args.encoding = 'utf-8'

        base_path = args.data_dir
        # Pre-trained word2vec embeddings for the poem characters.
        w2v_file = os.path.join(base_path, "vectors_poem.bin")
        self.w2v = Word2Vec(w2v_file)

        # Rhyme table used when sampling with a rhyme constraint.
        RhymeWords.read_rhyme_words(os.path.join(base_path, 'rhyme_words.txt'))

        if args.seed >= 0:
            np.random.seed(args.seed)

        logging.info('best_model: %s\n', best_model)

        self.sess = tf.Session()
        w2v_vocab_size = len(self.w2v.model.vocab)
        with tf.name_scope('evaluation'):
            self.model = CharRNNLM(is_training=False,
                                   w2v_model=self.w2v.model,
                                   vocab_size=w2v_vocab_size,
                                   infer=True,
                                   **params)
            saver = tf.train.Saver(name='model_saver')
            saver.restore(self.sess, best_model)

    def free_verse(self):
        '''
        Free verse (自由诗): sample 40 characters starting from the
        poem-start marker '[' and trim the result to (at most) the
        first two sentences.

        Returns:
            The generated poem text, or an error string when sampling
            produced nothing.
        '''
        sample = self.model.sample_seq(self.sess,
                                       40,
                                       '[',
                                       sample_type=SampleType.weighted_sample)
        if not sample:
            return 'err occar!'

        print('free_verse:', sample)

        # ']' presumably marks the end of a poem and '。' ends a
        # sentence — confirm against the training corpus markers.
        idx_end = sample.find(']')
        parts = sample.split('。')
        if len(parts) > 1:
            # Prefer the first two sentences when the end marker is
            # missing or lies beyond them (+2 keeps both '。').
            two_sentence_len = len(parts[0]) + len(parts[1])
            if idx_end < 0 or two_sentence_len < idx_end:
                return sample[1:two_sentence_len + 2]

        return sample[1:idx_end]

    @staticmethod
    def assemble(sample):
        # Glue the first two '。'-separated sentences back together,
        # dropping the leading '[' marker and truncating the second
        # sentence to the first sentence's length.
        if sample:
            parts = sample.split('。')
            if len(parts) > 1:
                return '{}。{}。'.format(parts[0][1:], parts[1][:len(parts[0])])

        return ''

    def rhyme_verse(self):
        '''
        Rhymed verse (押韵诗): sample a first sentence, then sample the
        continuation constrained to rhyme with that sentence's final
        character.

        Returns:
            The assembled two-sentence poem, or an error string.
        '''
        gen_len = 20
        sample = self.model.sample_seq(self.sess,
                                       gen_len,
                                       start_text='[',
                                       sample_type=SampleType.weighted_sample)
        if not sample:
            return 'err occar!'

        print('rhyme_verse:', sample)

        parts = sample.split('。')
        if len(parts) > 0:
            start = parts[0] + '。'
            # Character to rhyme with (the one before the closing '。')
            # and the offset at which the rhyme should land in the
            # continuation.
            rhyme_ref_word = start[-2]
            rhyme_seq = len(start) - 3

            sample = self.model.sample_seq(
                self.sess,
                gen_len,
                start,
                sample_type=SampleType.weighted_sample,
                rhyme_ref=rhyme_ref_word,
                rhyme_idx=rhyme_seq)
            print(sample)
            return WritePoem.assemble(sample)

        return sample[1:]

    def hide_words(self, given_text):
        '''
        Hidden-word poem (藏字诗): the given characters are embedded in
        the generated text, half in each sampled sentence.

        Args:
            given_text: characters to hide; falls back to a plain
                rhymed verse when empty.

        Returns:
            The assembled poem, or an error string.
        '''
        if (not given_text):
            return self.rhyme_verse()

        # Split the given characters: first half for sentence one,
        # second half for sentence two.
        givens = ['', '']
        split_len = math.ceil(len(given_text) / 2)
        givens[0] = given_text[:split_len]
        givens[1] = given_text[split_len:]

        gen_len = 20
        sample = self.model.sample_seq(self.sess,
                                       gen_len,
                                       start_text='[',
                                       sample_type=SampleType.select_given,
                                       given=givens[0])
        if not sample:
            return 'err occar!'

        print('rhyme_verse:', sample)

        parts = sample.split('。')
        if len(parts) > 0:
            start = parts[0] + '。'
            # Same rhyme bookkeeping as rhyme_verse: rhyme with the
            # last character of the first sentence.
            rhyme_ref_word = start[-2]
            rhyme_seq = len(start) - 3
            # gen_len = len(start) - 1

            sample = self.model.sample_seq(self.sess,
                                           gen_len,
                                           start,
                                           sample_type=SampleType.select_given,
                                           given=givens[1],
                                           rhyme_ref=rhyme_ref_word,
                                           rhyme_idx=rhyme_seq)
            print(sample)
            return WritePoem.assemble(sample)

        return sample[1:]

    def cangtou(self, given_text):
        '''
        Acrostic poem (藏头诗): each of the four lines begins with the
        corresponding character of ``given_text``.

        Args:
            given_text: up to four characters used as line heads;
                falls back to a plain rhymed verse when empty.

        Returns:
            The assembled poem string.
        '''
        if (not given_text):
            return self.rhyme_verse()

        start = ''
        rhyme_ref_word = ''
        rhyme_seq = 0

        # Always build exactly four lines; missing head characters are
        # left to the model.
        # for i,word in enumerate(given_text):
        for i in range(4):
            word = ''
            if i < len(given_text):
                word = given_text[i]

            if i == 0:
                start = '[' + word
            else:
                start += word

            # Everything at index >= before_idx is newly sampled text.
            before_idx = len(start)
            if (i != 3):
                sample = self.model.sample_seq(
                    self.sess,
                    self.args.length,
                    start,
                    sample_type=SampleType.weighted_sample)

            else:
                # Final line: greedy sampling under the rhyme constraint
                # captured from line two below.
                if not word:
                    rhyme_seq += 1

                sample = self.model.sample_seq(self.sess,
                                               self.args.length,
                                               start,
                                               sample_type=SampleType.max_prob,
                                               rhyme_ref=rhyme_ref_word,
                                               rhyme_idx=rhyme_seq)

            print('Sampled text is:\n\n%s' % sample)

            # Keep only up to the earliest sentence delimiter (',' or
            # '。') in the newly sampled portion.
            sample = sample[before_idx:]
            idx1 = sample.find(',')
            idx2 = sample.find('。')
            min_idx = min(idx1, idx2)

            # min() yields -1 when either delimiter is absent; fall back
            # to whichever one was actually found.
            if min_idx == -1:
                if idx1 > -1:
                    min_idx = idx1
                else:
                    min_idx = idx2
            if min_idx > 0:
                # last_sample.append(sample[:min_idx + 1])
                start = '{}{}'.format(start, sample[:min_idx + 1])

                # After line two, remember its last character and
                # position as the rhyme reference for line four.
                if i == 1:
                    rhyme_seq = min_idx - 1
                    rhyme_ref_word = sample[rhyme_seq]

            print('last_sample text is:\n\n%s' % start)

        return WritePoem.assemble(start)
# Example 4
def main(args=''):
    """Train the character-level RNN poem language model.

    Builds training/validation/test graphs that share one variable set,
    runs the training loop with learning-rate halving whenever the
    validation perplexity stops improving, and persists progress
    (latest model, best model, best validation ppl) to
    ``<output_dir>/result.json`` after every epoch and again on exit.

    Args:
        args: argument string/list forwarded to ``config_poem_train``.
    """
    args = config_poem_train(args)
    # Specifying location to store model, best model and tensorboard log.
    args.save_model = os.path.join(args.output_dir, 'save_model/model')
    args.save_best_model = os.path.join(args.output_dir, 'best_model/model')
    timestamp = str(int(time.time()))
    args.tb_log_dir = os.path.abspath(os.path.join(args.output_dir, "tensorboard_log", timestamp))
    print("Writing to {}\n".format(args.tb_log_dir))

    # Create necessary directories.
    if len(args.init_dir) != 0:
        # Resuming: write into the directory we are initializing from.
        args.output_dir = args.init_dir
    else:
        # Fresh run: wipe any previous output and recreate the tree.
        if os.path.exists(args.output_dir):
            shutil.rmtree(args.output_dir)
        for paths in [args.save_model, args.save_best_model, args.tb_log_dir]:
            os.makedirs(os.path.dirname(paths))

    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s:%(message)s',
                        level=logging.INFO, datefmt='%I:%M:%S')

    print('=' * 60)
    print('All final and intermediate outputs will be stored in %s/' % args.output_dir)
    print('=' * 60 + '\n')

    logging.info('args are:\n%s', args)

    if len(args.init_dir) != 0:
        # Resume: restore hyper-parameters and best-model bookkeeping
        # recorded by the previous run.
        with open(os.path.join(args.init_dir, 'result.json'), 'r') as f:
            result = json.load(f)
        params = result['params']
        args.init_model = result['latest_model']
        best_model = result['best_model']
        best_valid_ppl = result['best_valid_ppl']
        args.encoding = result.get('encoding', 'utf-8')
    else:
        params = {'batch_size': args.batch_size,
                  'num_unrollings': args.num_unrollings,
                  'hidden_size': args.hidden_size,
                  'max_grad_norm': args.max_grad_norm,
                  'embedding_size': args.embedding_size,
                  'num_layers': args.num_layers,
                  'learning_rate': args.learning_rate,
                  'cell_type': args.cell_type,
                  'dropout': args.dropout,
                  'input_dropout': args.input_dropout}
        best_model = ''
        # BUGFIX: best_valid_ppl was previously left undefined on a fresh
        # run; reading it before the first epoch completed (e.g. with
        # num_epochs == 0) raised a NameError.
        best_valid_ppl = float('inf')
    logging.info('Parameters are:\n%s\n', json.dumps(params, sort_keys=True, indent=4))

    # Create batch generators.
    batch_size = params['batch_size']
    num_unrollings = params['num_unrollings']

    base_path = args.data_path
    w2v_file = os.path.join(base_path, "vectors_poem.bin")
    w2v = Word2Vec(w2v_file)

    train_data_loader = DataLoader(base_path, batch_size, num_unrollings, w2v.model, 'train')
    test1_data_loader = DataLoader(base_path, batch_size, num_unrollings, w2v.model, 'test')
    valid_data_loader = DataLoader(base_path, batch_size, num_unrollings, w2v.model, 'valid')

    # Create graphs. The three models share variables (reuse_variables);
    # they differ only in their is_training behaviour.
    logging.info('Creating graph')
    graph = tf.Graph()
    with graph.as_default():
        w2v_vocab_size = len(w2v.model.vocab)
        with tf.name_scope('training'):
            train_model = CharRNNLM(is_training=True, w2v_model=w2v.model, vocab_size=w2v_vocab_size, infer=False, **params)
            tf.get_variable_scope().reuse_variables()

        with tf.name_scope('validation'):
            valid_model = CharRNNLM(is_training=False, w2v_model=w2v.model, vocab_size=w2v_vocab_size, infer=False, **params)

        with tf.name_scope('evaluation'):
            test_model = CharRNNLM(is_training=False, w2v_model=w2v.model, vocab_size=w2v_vocab_size, infer=False, **params)
            saver = tf.train.Saver(name='model_saver')
            best_model_saver = tf.train.Saver(name='best_model_saver')

    logging.info('Start training\n')

    result = {}
    result['params'] = params

    try:
        with tf.Session(graph=graph) as session:
            # Version 8 changed the api of summary writer to use
            # graph instead of graph_def.
            if TF_VERSION >= 8:
                graph_info = session.graph
            else:
                graph_info = session.graph_def

            train_summary_dir = os.path.join(args.tb_log_dir, "summaries", "train")
            train_writer = tf.summary.FileWriter(train_summary_dir, graph_info)
            valid_summary_dir = os.path.join(args.tb_log_dir, "summaries", "valid")
            valid_writer = tf.summary.FileWriter(valid_summary_dir, graph_info)

            # load a saved model or start from random initialization.
            if len(args.init_model) != 0:
                saver.restore(session, args.init_model)
            else:
                tf.global_variables_initializer().run()

            learning_rate = args.learning_rate
            for epoch in range(args.num_epochs):
                logging.info('=' * 19 + ' Epoch %d ' + '=' * 19 + '\n', epoch)
                logging.info('Training on training set')
                # training step
                ppl, train_summary_str, global_step = train_model.run_epoch(session, train_data_loader, is_training=True,
                                     learning_rate=learning_rate, verbose=args.verbose, freq=args.progress_freq)
                # record the summary
                train_writer.add_summary(train_summary_str, global_step)
                train_writer.flush()
                # save model
                saved_path = saver.save(session, args.save_model,
                                        global_step=train_model.global_step)

                logging.info('Latest model saved in %s\n', saved_path)
                logging.info('Evaluate on validation set')

                valid_ppl, valid_summary_str, _ = valid_model.run_epoch(session, valid_data_loader, is_training=False,
                                     learning_rate=learning_rate, verbose=args.verbose, freq=args.progress_freq)

                # save and update best model
                if (len(best_model) == 0) or (valid_ppl < best_valid_ppl):
                    best_model = best_model_saver.save(session, args.save_best_model,
                                                       global_step=train_model.global_step)
                    best_valid_ppl = valid_ppl
                else:
                    # No improvement: halve the learning rate.
                    learning_rate /= 2.0
                    logging.info('Decay the learning rate: %s', learning_rate)

                valid_writer.add_summary(valid_summary_str, global_step)
                valid_writer.flush()

                logging.info('Best model is saved in %s', best_model)
                logging.info('Best validation ppl is %f\n', best_valid_ppl)

                result['latest_model'] = saved_path
                result['best_model'] = best_model
                # Convert to float because numpy.float is not json serializable.
                result['best_valid_ppl'] = float(best_valid_ppl)

                # Checkpoint the bookkeeping after every epoch. Mode 'w'
                # truncates, so removing the file first is unnecessary.
                result_path = os.path.join(args.output_dir, 'result.json')
                with open(result_path, 'w', encoding='utf-8', errors='ignore') as f:
                    json.dump(result, f, indent=2, sort_keys=True)

            logging.info('Latest model is saved in %s', saved_path)
            logging.info('Best model is saved in %s', best_model)
            logging.info('Best validation ppl is %f\n', best_valid_ppl)

            logging.info('Evaluate the best model on test set')
            saver.restore(session, best_model)
            test_ppl, _, _ = test_model.run_epoch(session, test1_data_loader, is_training=False,
                                     learning_rate=learning_rate, verbose=args.verbose, freq=args.progress_freq)
            result['test_ppl'] = float(test_ppl)
    except Exception as e:
        # Log the full traceback instead of printing only the message.
        logging.exception('err :%s', e)
    finally:
        # Always persist whatever progress was recorded.
        result_path = os.path.join(args.output_dir, 'result.json')
        with open(result_path, 'w', encoding='utf-8', errors='ignore') as f:
            json.dump(result, f, indent=2, sort_keys=True)
# Example 5
def main(args=''):
    """Train the character-level RNN poem language model.

    Builds training/validation/test graphs that share one variable set,
    runs the training loop with learning-rate halving whenever the
    validation perplexity stops improving, and persists progress
    (latest model, best model, best validation ppl) to
    ``<output_dir>/result.json`` after every epoch and again on exit.

    Args:
        args: argument string/list forwarded to ``config_poem_train``.
    """
    args = config_poem_train(args)
    # Specifying location to store model, best model and tensorboard log.
    args.save_model = os.path.join(args.output_dir, 'save_model/model')
    args.save_best_model = os.path.join(args.output_dir, 'best_model/model')
    timestamp = str(int(time.time()))
    args.tb_log_dir = os.path.abspath(
        os.path.join(args.output_dir, "tensorboard_log", timestamp))
    print("Writing to {}\n".format(args.tb_log_dir))

    # Create necessary directories.
    if len(args.init_dir) != 0:
        # Resuming: write into the directory we are initializing from.
        args.output_dir = args.init_dir
    else:
        # Fresh run: wipe any previous output and recreate the tree.
        if os.path.exists(args.output_dir):
            shutil.rmtree(args.output_dir)
        for paths in [args.save_model, args.save_best_model, args.tb_log_dir]:
            os.makedirs(os.path.dirname(paths))

    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s:%(message)s',
                        level=logging.INFO,
                        datefmt='%I:%M:%S')

    print('=' * 60)
    print('All final and intermediate outputs will be stored in %s/' %
          args.output_dir)
    print('=' * 60 + '\n')

    logging.info('args are:\n%s', args)

    if len(args.init_dir) != 0:
        # Resume: restore hyper-parameters and best-model bookkeeping
        # recorded by the previous run.
        with open(os.path.join(args.init_dir, 'result.json'), 'r') as f:
            result = json.load(f)
        params = result['params']
        args.init_model = result['latest_model']
        best_model = result['best_model']
        best_valid_ppl = result['best_valid_ppl']
        args.encoding = result.get('encoding', 'utf-8')
    else:
        params = {
            'batch_size': args.batch_size,
            'num_unrollings': args.num_unrollings,
            'hidden_size': args.hidden_size,
            'max_grad_norm': args.max_grad_norm,
            'embedding_size': args.embedding_size,
            'num_layers': args.num_layers,
            'learning_rate': args.learning_rate,
            'cell_type': args.cell_type,
            'dropout': args.dropout,
            'input_dropout': args.input_dropout
        }
        best_model = ''
        # BUGFIX: best_valid_ppl was previously left undefined on a fresh
        # run; reading it before the first epoch completed (e.g. with
        # num_epochs == 0) raised a NameError.
        best_valid_ppl = float('inf')
    logging.info('Parameters are:\n%s\n',
                 json.dumps(params, sort_keys=True, indent=4))

    # Create batch generators.
    batch_size = params['batch_size']
    num_unrollings = params['num_unrollings']

    base_path = args.data_path
    w2v_file = os.path.join(base_path, "vectors_poem.bin")
    w2v = Word2Vec(w2v_file)

    train_data_loader = DataLoader(base_path, batch_size, num_unrollings,
                                   w2v.model, 'train')
    test1_data_loader = DataLoader(base_path, batch_size, num_unrollings,
                                   w2v.model, 'test')
    valid_data_loader = DataLoader(base_path, batch_size, num_unrollings,
                                   w2v.model, 'valid')

    # Create graphs. The three models share variables (reuse_variables);
    # they differ only in their is_training behaviour.
    logging.info('Creating graph')
    graph = tf.Graph()
    with graph.as_default():
        w2v_vocab_size = len(w2v.model.vocab)
        with tf.name_scope('training'):
            train_model = CharRNNLM(is_training=True,
                                    w2v_model=w2v.model,
                                    vocab_size=w2v_vocab_size,
                                    infer=False,
                                    **params)
            tf.get_variable_scope().reuse_variables()

        with tf.name_scope('validation'):
            valid_model = CharRNNLM(is_training=False,
                                    w2v_model=w2v.model,
                                    vocab_size=w2v_vocab_size,
                                    infer=False,
                                    **params)

        with tf.name_scope('evaluation'):
            test_model = CharRNNLM(is_training=False,
                                   w2v_model=w2v.model,
                                   vocab_size=w2v_vocab_size,
                                   infer=False,
                                   **params)
            saver = tf.train.Saver(name='model_saver')
            best_model_saver = tf.train.Saver(name='best_model_saver')

    logging.info('Start training\n')

    result = {}
    result['params'] = params

    try:
        with tf.Session(graph=graph) as session:
            # Version 8 changed the api of summary writer to use
            # graph instead of graph_def.
            if TF_VERSION >= 8:
                graph_info = session.graph
            else:
                graph_info = session.graph_def

            train_summary_dir = os.path.join(args.tb_log_dir, "summaries",
                                             "train")
            train_writer = tf.summary.FileWriter(train_summary_dir, graph_info)
            valid_summary_dir = os.path.join(args.tb_log_dir, "summaries",
                                             "valid")
            valid_writer = tf.summary.FileWriter(valid_summary_dir, graph_info)

            # load a saved model or start from random initialization.
            if len(args.init_model) != 0:
                saver.restore(session, args.init_model)
            else:
                tf.global_variables_initializer().run()

            learning_rate = args.learning_rate
            for epoch in range(args.num_epochs):
                logging.info('=' * 19 + ' Epoch %d ' + '=' * 19 + '\n', epoch)
                logging.info('Training on training set')
                # training step
                ppl, train_summary_str, global_step = train_model.run_epoch(
                    session,
                    train_data_loader,
                    is_training=True,
                    learning_rate=learning_rate,
                    verbose=args.verbose,
                    freq=args.progress_freq)
                # record the summary
                train_writer.add_summary(train_summary_str, global_step)
                train_writer.flush()
                # save model
                saved_path = saver.save(session,
                                        args.save_model,
                                        global_step=train_model.global_step)

                logging.info('Latest model saved in %s\n', saved_path)
                logging.info('Evaluate on validation set')

                valid_ppl, valid_summary_str, _ = valid_model.run_epoch(
                    session,
                    valid_data_loader,
                    is_training=False,
                    learning_rate=learning_rate,
                    verbose=args.verbose,
                    freq=args.progress_freq)

                # save and update best model
                if (len(best_model) == 0) or (valid_ppl < best_valid_ppl):
                    best_model = best_model_saver.save(
                        session,
                        args.save_best_model,
                        global_step=train_model.global_step)
                    best_valid_ppl = valid_ppl
                else:
                    # No improvement: halve the learning rate.
                    learning_rate /= 2.0
                    logging.info('Decay the learning rate: %s', learning_rate)

                valid_writer.add_summary(valid_summary_str, global_step)
                valid_writer.flush()

                logging.info('Best model is saved in %s', best_model)
                logging.info('Best validation ppl is %f\n', best_valid_ppl)

                result['latest_model'] = saved_path
                result['best_model'] = best_model
                # Convert to float because numpy.float is not json serializable.
                result['best_valid_ppl'] = float(best_valid_ppl)

                # Checkpoint the bookkeeping after every epoch. Mode 'w'
                # truncates, so removing the file first is unnecessary.
                result_path = os.path.join(args.output_dir, 'result.json')
                with open(result_path, 'w', encoding='utf-8',
                          errors='ignore') as f:
                    json.dump(result, f, indent=2, sort_keys=True)

            logging.info('Latest model is saved in %s', saved_path)
            logging.info('Best model is saved in %s', best_model)
            logging.info('Best validation ppl is %f\n', best_valid_ppl)

            logging.info('Evaluate the best model on test set')
            saver.restore(session, best_model)
            test_ppl, _, _ = test_model.run_epoch(session,
                                                  test1_data_loader,
                                                  is_training=False,
                                                  learning_rate=learning_rate,
                                                  verbose=args.verbose,
                                                  freq=args.progress_freq)
            result['test_ppl'] = float(test_ppl)
    except Exception as e:
        # Log the full traceback instead of printing only the message.
        logging.exception('err :%s', e)
    finally:
        # Always persist whatever progress was recorded.
        result_path = os.path.join(args.output_dir, 'result.json')
        with open(result_path, 'w', encoding='utf-8', errors='ignore') as f:
            json.dump(result, f, indent=2, sort_keys=True)