Пример #1
0
def convert_examples_to_features(examples, config):
    """Turn raw examples into ``InputFeatures`` holding source/target id lists.

    Each example's ``src`` (and ``tgt``, when truthy) is whitespace-split,
    truncated to ``config.max_seq_length`` tokens, wrapped in ``config.sos`` /
    ``config.eos`` and mapped to ids with ``convert_list``. The first five
    converted examples are written to the log.

    Args:
        examples: iterable of objects exposing ``guid``, ``src`` and ``tgt``.
        config: settings object providing ``sos``, ``eos``, ``max_seq_length``,
            ``to_lower``, ``src_2_id``, ``tgt_2_id``, ``pad_id`` and ``unk_id``.

    Returns:
        A list with one ``InputFeatures`` per example, in input order.
    """
    features = []
    progress = tqdm(examples, desc='Converting Examples')
    for index, example in enumerate(progress):
        src_words = example.src.split()
        # A falsy target (None or empty string) yields an empty target body,
        # so the target sequence is just the sos/eos pair.
        tgt_words = example.tgt.split() if example.tgt else []

        src_seq = [config.sos] + src_words[:config.max_seq_length] + [config.eos]
        tgt_seq = [config.sos] + tgt_words[:config.max_seq_length] + [config.eos]

        if config.to_lower:
            # NOTE(review): lower-casing happens after sos/eos are added, so
            # the special tokens are lower-cased as well -- confirm that the
            # vocabularies store them in lower case.
            src_seq = [str.lower(token) for token in src_seq]
            tgt_seq = [str.lower(token) for token in tgt_seq]

        src_ids = convert_list(
            src_seq, config.src_2_id, config.pad_id, config.unk_id)
        tgt_ids = convert_list(
            tgt_seq, config.tgt_2_id, config.pad_id, config.unk_id)

        features.append(InputFeatures(example.guid, src_ids, tgt_ids))

        # Only the first five examples are echoed to the log.
        if index >= 5:
            continue
        logger.info(log_title('Examples'))
        logger.info('guid: {}'.format(example.guid))
        logger.info('source input: {}'.format(src_seq))
        logger.info('source ids: {}'.format(src_ids))
        logger.info('target input: {}'.format(tgt_seq))
        logger.info('target ids: {}'.format(tgt_ids))

    return features
Пример #2
0
    def _read_data(self, data_file):
        """Read a JSON-lines file and encode every record as four id lists.

        For each record the ``topic``, ``triples`` and ``src`` fields are
        joined with ``config.sep`` (each triple is space-joined first), while
        ``tgt`` is used as-is. Texts are optionally lower-cased, whitespace
        tokenised, truncated, wrapped in ``config.sos`` / ``config.eos`` and
        converted with ``convert_list``. The first five records are logged.

        Args:
            data_file: path handed to ``read_json_lines``.

        Returns:
            A 4-tuple ``(topic, triple, src, tgt)`` of parallel lists, one
            entry per record.
        """
        records = list(read_json_lines(data_file))
        cfg = self.config

        def wrap(tokens):
            # Surround a token list with the start/end markers.
            return [cfg.sos] + tokens + [cfg.eos]

        def to_ids(tokens):
            # Map tokens to ids using the shared word vocabulary.
            return convert_list(tokens, cfg.word_2_id, cfg.pad_id, cfg.unk_id)

        topic_rows, triple_rows, src_rows, tgt_rows = [], [], [], []

        for index, record in enumerate(tqdm(records)):
            joiner = ' {} '.format(cfg.sep)
            topic_text = joiner.join(record['topic'])
            triple_text = joiner.join(
                [' '.join(v) for v in record['triples']])
            src_text = joiner.join(record['src'])
            tgt_text = record['tgt']

            if cfg.to_lower:
                topic_text = topic_text.lower()
                triple_text = triple_text.lower()
                src_text = src_text.lower()
                tgt_text = tgt_text.lower()

            # Topic is not truncated; triples and target keep their leading
            # tokens, whereas the source keeps its trailing tokens.
            topic_tokens = wrap(topic_text.split())
            triple_tokens = wrap(
                triple_text.split()[:cfg.max_triple_length])
            src_tokens = wrap(src_text.split()[-cfg.max_seq_length:])
            tgt_tokens = wrap(tgt_text.split()[:cfg.max_seq_length])

            topic_ids = to_ids(topic_tokens)
            triple_ids = to_ids(triple_tokens)
            src_ids = to_ids(src_tokens)
            tgt_ids = to_ids(tgt_tokens)

            topic_rows.append(topic_ids)
            triple_rows.append(triple_ids)
            src_rows.append(src_ids)
            tgt_rows.append(tgt_ids)

            # Only the first five records are echoed to the log.
            if index >= 5:
                continue
            logger.info(log_title('Examples'))
            logger.info('topic tokens: {}'.format(topic_tokens))
            logger.info('topic ids: {}'.format(topic_ids))
            logger.info('triple tokens: {}'.format(triple_tokens))
            logger.info('triple ids: {}'.format(triple_ids))
            logger.info('source tokens: {}'.format(src_tokens))
            logger.info('source ids: {}'.format(src_ids))
            logger.info('target tokens: {}'.format(tgt_tokens))
            logger.info('target ids: {}'.format(tgt_ids))

        return topic_rows, triple_rows, src_rows, tgt_rows
Пример #3
0
def main():
    """Prepare data, build the model and run the requested train/eval/test phases.

    Reads the module-level ``config``, ``args`` and ``sess_config``. The
    phases selected by ``args.do_train``, ``args.do_eval`` and
    ``args.do_test`` run in that order, each inside its own ``tf.Session``.
    Every phase restores ``args.model_file`` when given, otherwise the latest
    checkpoint under ``config.result_dir/config.current_model``.
    """
    for directory in (config.temp_dir, config.result_dir,
                      config.train_log_dir):
        os.makedirs(directory, exist_ok=True)

    logger.setLevel(logging.INFO)
    init_logger(logging.INFO, 'temp.log.txt', 'w')

    logger.info('preparing data...')
    config.word_2_id, config.id_2_word = read_json_dict(config.vocab_dict)
    config.vocab_size = min(config.vocab_size, len(config.word_2_id))
    config.oov_vocab_size = min(config.oov_vocab_size,
                                len(config.word_2_id) - config.vocab_size)

    embedding_matrix = None
    if os.path.exists(config.glove_file):
        if args.do_train:
            logger.info('loading embedding matrix from file: {}'.format(
                config.glove_file))
            embedding_matrix, config.word_em_size = load_glove_embedding(
                config.glove_file, list(config.word_2_id.keys()))
            logger.info('shape of embedding matrix: {}'.format(
                embedding_matrix.shape))
        else:
            # Without training only the embedding width is needed; it is
            # derived from the first line as (number of fields - 1).
            with open(config.glove_file, 'r', encoding='utf-8') as fin:
                line = fin.readline()
                config.word_em_size = len(line.strip().split()) - 1

    data_reader = DataReader(config)
    evaluator = Evaluator('tgt')

    logger.info('building model...')
    model = get_model(config, embedding_matrix)
    saver = tf.train.Saver(max_to_keep=10)

    def restore_checkpoint(sess):
        """Restore the requested or latest checkpoint; return True if found."""
        model_file = args.model_file
        if model_file is None:
            model_file = tf.train.latest_checkpoint(
                os.path.join(config.result_dir, config.current_model))
        if model_file is None:
            return False
        logger.info('loading model from {}...'.format(model_file))
        saver.restore(sess, model_file)
        return True

    if args.do_train:
        logger.info('loading data...')
        train_data = data_reader.read_train_data()
        valid_data = data_reader.read_valid_data()

        logger.info(log_title('Trainable Variables'))
        for v in tf.trainable_variables():
            logger.info(v)

        logger.info(log_title('Gradients'))
        for g in model.gradients:
            logger.info(g)

        with tf.Session(config=sess_config) as sess:
            if not restore_checkpoint(sess):
                logger.info('initializing from scratch...')
                tf.global_variables_initializer().run()

            train_writer = tf.summary.FileWriter(config.train_log_dir,
                                                 sess.graph)
            try:
                valid_log_history = run_train(sess, model, train_data,
                                              valid_data, saver, evaluator,
                                              train_writer)
            finally:
                # Close the writer even if training fails so that pending
                # summaries are flushed and the event file is released.
                train_writer.close()
            save_json(
                valid_log_history,
                os.path.join(config.result_dir, config.current_model,
                             'valid_log_history.json'))

    if args.do_eval:
        logger.info('loading data...')
        valid_data = data_reader.read_valid_data()

        with tf.Session(config=sess_config) as sess:
            if restore_checkpoint(sess):
                predicted_ids, valid_loss, valid_accu = run_evaluate(
                    sess, model, valid_data)
                logger.info(
                    'average valid loss: {:>.4f}, average valid accuracy: {:>.4f}'
                    .format(valid_loss, valid_accu))

                logger.info(log_title('Saving Result'))
                save_outputs(predicted_ids, config.id_2_word,
                             config.valid_data, config.valid_outputs)
                results = evaluator.evaluate(config.valid_data,
                                             config.valid_outputs,
                                             config.to_lower)
                save_json(results, config.valid_results)
            else:
                logger.info('model not found!')

    if args.do_test:
        logger.info('loading data...')
        test_data = data_reader.read_test_data()

        with tf.Session(config=sess_config) as sess:
            if restore_checkpoint(sess):
                predicted_ids = run_test(sess, model, test_data)

                logger.info(log_title('Saving Result'))
                save_outputs(predicted_ids, config.id_2_word, config.test_data,
                             config.test_outputs)
                results = evaluator.evaluate(config.test_data,
                                             config.test_outputs,
                                             config.to_lower)
                save_json(results, config.test_results)
            else:
                logger.info('model not found!')
Пример #4
0
def run_train(sess,
              model,
              train_data,
              valid_data,
              saver,
              evaluator,
              summary_writer=None):
    """Train ``model`` for ``config.num_epoch`` epochs with periodic validation.

    Every ``args.save_steps`` global steps the model is either saved
    unconditionally (during the first ``args.pre_train_epochs`` epochs) or
    evaluated on ``valid_data``. After pre-training a checkpoint is written
    only when the validation ``'BLEU 4'`` score is at least as good as the
    best seen so far. When ``args.early_stop`` is non-zero, training stops
    once that many evaluations in a row failed to improve and another one
    fails too. A final checkpoint is saved when all epochs complete.

    Args:
        sess: session used for ``sess.run`` and passed to ``saver.save``.
        model: object exposing ``train_op``, ``loss``, ``accu``,
            ``global_step``, ``summary`` and the placeholders fed below.
        train_data: parallel columns; ``zip(*train_data)`` yields examples.
        valid_data: forwarded unchanged to ``run_evaluate``.
        saver: object with ``save(sess, path, global_step=...)``.
        evaluator: object whose ``evaluate`` returns a mapping that contains
            the key ``'BLEU 4'``.
        summary_writer: optional writer; receives a summary every
            ``args.log_steps`` global steps.

    Returns:
        A ``defaultdict(list)`` mapping each validation metric name, plus
        ``'loss'``, ``'accuracy'`` and ``'global_step'``, to the list of
        values recorded at each evaluation. The evaluation that triggers
        early stopping is included.
    """
    stale_evals = 0  # consecutive evaluations without improvement
    best_valid_result = 0.0
    valid_log_history = defaultdict(list)
    global_step = 0
    for epoch in range(config.num_epoch):
        logger.info(log_title('Train Epoch: {}'.format(epoch + 1)))
        steps = 0
        total_loss = 0.0
        total_accu = 0.0
        # The batches are materialised into a list before being handed to
        # tqdm (presumably so the progress bar can show a total).
        batch_iter = tqdm(
            list(
                make_batch_iter(list(zip(*train_data)),
                                config.batch_size,
                                shuffle=True)))
        for batch in batch_iter:
            (topic, topic_len, triple, triple_len, src, src_len, tgt,
             tgt_len) = make_batch_data(batch)

            _, loss, accu, global_step, summary = sess.run(
                [
                    model.train_op, model.loss, model.accu, model.global_step,
                    model.summary
                ],
                feed_dict={
                    model.batch_size: len(topic),
                    model.topic: topic,
                    model.topic_len: topic_len,
                    model.triple: triple,
                    model.triple_len: triple_len,
                    model.src: src,
                    model.src_len: src_len,
                    model.tgt: tgt,
                    model.tgt_len: tgt_len,
                    model.training: True
                })

            steps += 1
            total_loss += loss
            total_accu += accu
            batch_iter.set_description(
                'loss: {:>.4f} accuracy: {:>.4f}'.format(loss, accu))
            if global_step % args.log_steps == 0 and summary_writer is not None:
                summary_writer.add_summary(summary, global_step)

            if global_step % args.save_steps != 0:
                continue

            # During pre-training epochs checkpoints are saved without
            # evaluation; evaluation starts afterwards.
            if epoch < args.pre_train_epochs:
                saver.save(sess, config.model_file, global_step=global_step)
                continue

            predicted_ids, valid_loss, valid_accu = run_evaluate(
                sess, model, valid_data)
            logger.info('valid loss: {:>.4f}, valid accuracy: {:>.4f}'.format(
                valid_loss, valid_accu))

            save_outputs(predicted_ids, config.id_2_word, config.valid_data,
                         config.valid_outputs)
            valid_results = evaluator.evaluate(config.valid_data,
                                               config.valid_outputs,
                                               config.to_lower)

            # Record this evaluation BEFORE the early-stop decision so the
            # history returned on early stop includes the evaluation that
            # triggered it.
            for key, value in valid_results.items():
                valid_log_history[key].append(value)
            valid_log_history['loss'].append(valid_loss)
            valid_log_history['accuracy'].append(valid_accu)
            valid_log_history['global_step'].append(int(global_step))

            # early stop
            if valid_results['BLEU 4'] >= best_valid_result:
                stale_evals = 0
                best_valid_result = valid_results['BLEU 4']
                logger.info('saving model-{}'.format(global_step))
                saver.save(sess, config.model_file, global_step=global_step)
                save_json(valid_results, config.valid_results)
            elif stale_evals < args.early_stop:
                stale_evals += 1
            elif args.early_stop:
                # Patience exhausted; a zero ``early_stop`` never stops.
                return valid_log_history

        if steps:
            logger.info('train loss: {:>.4f}, train accuracy: {:>.4f}'.format(
                total_loss / steps, total_accu / steps))
        else:
            # No batches this epoch: skip the averages instead of dividing
            # by zero.
            logger.warning('no training batches in epoch {}'.format(epoch + 1))
    saver.save(sess, config.model_file, global_step=global_step)

    return valid_log_history
Пример #5
0
    # Test phase: load the test split, restore a checkpoint, decode and score.
    if args.do_test:
        logger.info('loading data...')
        test_data = data_reader.read_test_data()

        with tf.Session(config=sess_config) as sess:
            # Prefer the explicitly requested checkpoint; otherwise fall back
            # to the latest one under result_dir/current_model.
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(
                    os.path.join(config.result_dir, config.current_model))
            if model_file is not None:
                logger.info('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)

                predicted_ids = run_test(sess, model, test_data)

                # Write the predictions, then evaluate the written outputs
                # against the test data and store the metrics.
                logger.info(log_title('Saving Result'))
                save_outputs(predicted_ids, config.id_2_word, config.test_data,
                             config.test_outputs)
                results = evaluator.evaluate(config.test_data,
                                             config.test_outputs,
                                             config.to_lower)
                save_json(results, config.test_results)
            else:
                # Neither an explicit model file nor a checkpoint was found.
                logger.info('model not found!')


# Script entry point: run the pipeline only when executed directly, then log
# a closing banner.
if __name__ == '__main__':
    main()

    logger.info(log_title('done'))
Пример #6
0
def main():
    """Load vocabularies, build the model and run the requested phases.

    Reads the module-level ``config``, ``args`` and ``sess_config``. Separate
    source and target vocabularies are loaded, and the configured vocabulary
    sizes are capped at the sizes of the loaded dictionaries. The phases
    selected by ``args.do_train``, ``args.do_eval`` and ``args.do_test`` run
    in that order, each inside its own ``tf.Session``. Every phase restores
    ``args.model_file`` when given, otherwise the latest checkpoint under
    ``config.result_dir/config.current_model``.
    """
    for directory in (config.temp_dir, config.result_dir,
                      config.train_log_dir):
        os.makedirs(directory, exist_ok=True)

    logger.setLevel(logging.INFO)
    init_logger(logging.INFO)

    logger.info('loading dict...')
    config.src_2_id, config.id_2_src = read_json_dict(config.src_vocab_dict)
    config.src_vocab_size = min(config.src_vocab_size, len(config.src_2_id))
    config.tgt_2_id, config.id_2_tgt = read_json_dict(config.tgt_vocab_dict)
    config.tgt_vocab_size = min(config.tgt_vocab_size, len(config.tgt_2_id))

    data_reader = DataReader(config)
    evaluator = Evaluator('tgt')

    logger.info('building model...')
    model = get_model(config)
    saver = tf.train.Saver(max_to_keep=10)

    def restore_checkpoint(sess):
        """Restore the requested or latest checkpoint; return True if found."""
        model_file = args.model_file
        if model_file is None:
            model_file = tf.train.latest_checkpoint(
                os.path.join(config.result_dir, config.current_model))
        if model_file is None:
            return False
        logger.info('loading model from {}...'.format(model_file))
        saver.restore(sess, model_file)
        return True

    if args.do_train:
        logger.info('loading data...')
        train_data = data_reader.load_train_data()
        valid_data = data_reader.load_valid_data()

        logger.info(log_title('Trainable Variables'))
        for v in tf.trainable_variables():
            logger.info(v)

        logger.info(log_title('Gradients'))
        for g in model.gradients:
            logger.info(g)

        with tf.Session(config=sess_config) as sess:
            if not restore_checkpoint(sess):
                logger.info('initializing from scratch...')
                tf.global_variables_initializer().run()

            train_writer = tf.summary.FileWriter(config.train_log_dir,
                                                 sess.graph)
            try:
                valid_log_history = run_train(sess, model, train_data,
                                              valid_data, saver, evaluator,
                                              train_writer)
            finally:
                # Close the writer even if training fails so that pending
                # summaries are flushed and the event file is released.
                train_writer.close()
            save_json(
                valid_log_history,
                os.path.join(config.result_dir, config.current_model,
                             'valid_log_history.json'))

    if args.do_eval:
        logger.info('loading data...')
        valid_data = data_reader.load_valid_data()

        with tf.Session(config=sess_config) as sess:
            if restore_checkpoint(sess):
                predicted_ids, valid_loss, valid_accu = run_evaluate(
                    sess, model, valid_data)
                logger.info(
                    'average valid loss: {:>.4f}, average valid accuracy: {:>.4f}'
                    .format(valid_loss, valid_accu))

                logger.info(log_title('Saving Result'))
                save_outputs(predicted_ids, config.id_2_tgt, config.valid_data,
                             config.valid_outputs)
                results = evaluator.evaluate(config.valid_data,
                                             config.valid_outputs,
                                             config.to_lower)
                save_json(results, config.valid_results)
            else:
                logger.info('model not found!')

    if args.do_test:
        logger.info('loading data...')
        test_data = data_reader.load_test_data()

        with tf.Session(config=sess_config) as sess:
            if restore_checkpoint(sess):
                predicted_ids = run_test(sess, model, test_data)

                logger.info(log_title('Saving Result'))
                save_outputs(predicted_ids, config.id_2_tgt, config.test_data,
                             config.test_outputs)
                results = evaluator.evaluate(config.test_data,
                                             config.test_outputs,
                                             config.to_lower)
                save_json(results, config.test_results)
            else:
                logger.info('model not found!')