def convert_examples_to_features(examples, config):
    features = []
    for index, example in enumerate(tqdm(examples, desc='Converting Examples')):
        src_seq, tgt_seq = [], []
        for word in example.src.split():
            src_seq.append(word)
        if example.tgt:
            for word in example.tgt.split():
                tgt_seq.append(word)
        src_seq = [config.sos] + src_seq[:config.max_seq_length] + [config.eos]
        tgt_seq = [config.sos] + tgt_seq[:config.max_seq_length] + [config.eos]
        if config.to_lower:
            src_seq = list(map(str.lower, src_seq))
            tgt_seq = list(map(str.lower, tgt_seq))

        src_ids = convert_list(src_seq, config.src_2_id, config.pad_id, config.unk_id)
        tgt_ids = convert_list(tgt_seq, config.tgt_2_id, config.pad_id, config.unk_id)
        features.append(InputFeatures(example.guid, src_ids, tgt_ids))

        if index < 5:
            logger.info(log_title('Examples'))
            logger.info('guid: {}'.format(example.guid))
            logger.info('source input: {}'.format(src_seq))
            logger.info('source ids: {}'.format(src_ids))
            logger.info('target input: {}'.format(tgt_seq))
            logger.info('target ids: {}'.format(tgt_ids))

    return features
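# The convert_list helper called above is not shown in this excerpt. A minimal sketch, assuming
# it simply maps each token to its vocabulary id and falls back to unk_id for out-of-vocabulary
# tokens; the pad_id argument is accepted here only to match the call sites above, on the
# assumption that padding actually happens later, at batching time.
def convert_list(tokens, mapping, pad_id, unk_id):
    """Map a list of tokens to ids, using unk_id for tokens missing from the vocabulary."""
    return [mapping.get(token, unk_id) for token in tokens]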
def _read_data(self, data_file):
    topic = []
    triple = []
    src = []
    tgt = []
    data_iter = tqdm(list(read_json_lines(data_file)))
    for index, line in enumerate(data_iter):
        topic_seq = ' {} '.format(self.config.sep).join(line['topic'])
        triple_seq = ' {} '.format(self.config.sep).join([' '.join(v) for v in line['triples']])
        src_seq = ' {} '.format(self.config.sep).join(line['src'])
        tgt_seq = line['tgt']
        if self.config.to_lower:
            topic_seq = topic_seq.lower()
            triple_seq = triple_seq.lower()
            src_seq = src_seq.lower()
            tgt_seq = tgt_seq.lower()

        topic_tokens = [self.config.sos] + topic_seq.split() + [self.config.eos]
        triple_tokens = [self.config.sos] + triple_seq.split()[:self.config.max_triple_length] + [self.config.eos]
        src_tokens = [self.config.sos] + src_seq.split()[-self.config.max_seq_length:] + [self.config.eos]
        tgt_tokens = [self.config.sos] + tgt_seq.split()[:self.config.max_seq_length] + [self.config.eos]

        topic_ids = convert_list(topic_tokens, self.config.word_2_id, self.config.pad_id, self.config.unk_id)
        triple_ids = convert_list(triple_tokens, self.config.word_2_id, self.config.pad_id, self.config.unk_id)
        src_ids = convert_list(src_tokens, self.config.word_2_id, self.config.pad_id, self.config.unk_id)
        tgt_ids = convert_list(tgt_tokens, self.config.word_2_id, self.config.pad_id, self.config.unk_id)

        topic.append(topic_ids)
        triple.append(triple_ids)
        src.append(src_ids)
        tgt.append(tgt_ids)

        if index < 5:
            logger.info(log_title('Examples'))
            logger.info('topic tokens: {}'.format(topic_tokens))
            logger.info('topic ids: {}'.format(topic_ids))
            logger.info('triple tokens: {}'.format(triple_tokens))
            logger.info('triple ids: {}'.format(triple_ids))
            logger.info('source tokens: {}'.format(src_tokens))
            logger.info('source ids: {}'.format(src_ids))
            logger.info('target tokens: {}'.format(tgt_tokens))
            logger.info('target ids: {}'.format(tgt_ids))

    return topic, triple, src, tgt
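# For reference, a hypothetical record illustrating the JSON-lines layout _read_data expects:
# 'topic' and 'src' are lists of strings (joined with the sep token), 'triples' is a list of
# token lists (one per knowledge triple), and 'tgt' is a single string. The concrete values
# below are made up for illustration only.
example_line = {
    "topic": ["machine learning"],
    "triples": [["machine learning", "is_a", "field of study"]],
    "src": ["what is machine learning ?"],
    "tgt": "it is a field of study about learning from data ."
}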
def main():
    os.makedirs(config.temp_dir, exist_ok=True)
    os.makedirs(config.result_dir, exist_ok=True)
    os.makedirs(config.train_log_dir, exist_ok=True)

    logger.setLevel(logging.INFO)
    init_logger(logging.INFO, 'temp.log.txt', 'w')

    logger.info('preparing data...')
    config.word_2_id, config.id_2_word = read_json_dict(config.vocab_dict)
    config.vocab_size = min(config.vocab_size, len(config.word_2_id))
    config.oov_vocab_size = min(config.oov_vocab_size, len(config.word_2_id) - config.vocab_size)

    embedding_matrix = None
    if args.do_train:
        if os.path.exists(config.glove_file):
            logger.info('loading embedding matrix from file: {}'.format(config.glove_file))
            embedding_matrix, config.word_em_size = load_glove_embedding(
                config.glove_file, list(config.word_2_id.keys()))
            logger.info('shape of embedding matrix: {}'.format(embedding_matrix.shape))
    else:
        if os.path.exists(config.glove_file):
            with open(config.glove_file, 'r', encoding='utf-8') as fin:
                line = fin.readline()
                config.word_em_size = len(line.strip().split()) - 1

    data_reader = DataReader(config)
    evaluator = Evaluator('tgt')

    logger.info('building model...')
    model = get_model(config, embedding_matrix)
    saver = tf.train.Saver(max_to_keep=10)

    if args.do_train:
        logger.info('loading data...')
        train_data = data_reader.read_train_data()
        valid_data = data_reader.read_valid_data()

        logger.info(log_title('Trainable Variables'))
        for v in tf.trainable_variables():
            logger.info(v)

        logger.info(log_title('Gradients'))
        for g in model.gradients:
            logger.info(g)

        with tf.Session(config=sess_config) as sess:
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(
                    os.path.join(config.result_dir, config.current_model))
            if model_file is not None:
                logger.info('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)
            else:
                logger.info('initializing from scratch...')
                tf.global_variables_initializer().run()

            train_writer = tf.summary.FileWriter(config.train_log_dir, sess.graph)
            valid_log_history = run_train(sess, model, train_data, valid_data,
                                          saver, evaluator, train_writer)
            save_json(valid_log_history,
                      os.path.join(config.result_dir, config.current_model,
                                   'valid_log_history.json'))

    if args.do_eval:
        logger.info('loading data...')
        valid_data = data_reader.read_valid_data()

        with tf.Session(config=sess_config) as sess:
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(
                    os.path.join(config.result_dir, config.current_model))
            if model_file is not None:
                logger.info('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)
                predicted_ids, valid_loss, valid_accu = run_evaluate(sess, model, valid_data)
                logger.info('average valid loss: {:>.4f}, average valid accuracy: {:>.4f}'.format(
                    valid_loss, valid_accu))

                logger.info(log_title('Saving Result'))
                save_outputs(predicted_ids, config.id_2_word, config.valid_data, config.valid_outputs)
                results = evaluator.evaluate(config.valid_data, config.valid_outputs, config.to_lower)
                save_json(results, config.valid_results)
            else:
                logger.info('model not found!')

    if args.do_test:
        logger.info('loading data...')
        test_data = data_reader.read_test_data()

        with tf.Session(config=sess_config) as sess:
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(
                    os.path.join(config.result_dir, config.current_model))
            if model_file is not None:
                logger.info('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)
                predicted_ids = run_test(sess, model, test_data)

                logger.info(log_title('Saving Result'))
                save_outputs(predicted_ids, config.id_2_word, config.test_data, config.test_outputs)
                results = evaluator.evaluate(config.test_data, config.test_outputs, config.to_lower)
                save_json(results, config.test_results)
            else:
                logger.info('model not found!')
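# The args object is parsed elsewhere in the repository. A minimal argparse sketch covering the
# attributes referenced by main() and run_train(); the option names mirror the usage above,
# while the default values here are illustrative assumptions only.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--do_train', action='store_true')
parser.add_argument('--do_eval', action='store_true')
parser.add_argument('--do_test', action='store_true')
parser.add_argument('--model_file', type=str, default=None)
parser.add_argument('--log_steps', type=int, default=100)
parser.add_argument('--save_steps', type=int, default=1000)
parser.add_argument('--pre_train_epochs', type=int, default=0)
parser.add_argument('--early_stop', type=int, default=5)
args = parser.parse_args()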
def run_train(sess, model, train_data, valid_data, saver, evaluator, summary_writer=None):
    flag = 0
    best_valid_result = 0.0
    valid_log_history = defaultdict(list)
    global_step = 0
    for i in range(config.num_epoch):
        logger.info(log_title('Train Epoch: {}'.format(i + 1)))
        steps = 0
        total_loss = 0.0
        total_accu = 0.0
        batch_iter = tqdm(list(make_batch_iter(list(zip(*train_data)), config.batch_size, shuffle=True)))
        for batch in batch_iter:
            topic, topic_len, triple, triple_len, src, src_len, tgt, tgt_len = make_batch_data(batch)

            _, loss, accu, global_step, summary = sess.run(
                [model.train_op, model.loss, model.accu, model.global_step, model.summary],
                feed_dict={
                    model.batch_size: len(topic),
                    model.topic: topic,
                    model.topic_len: topic_len,
                    model.triple: triple,
                    model.triple_len: triple_len,
                    model.src: src,
                    model.src_len: src_len,
                    model.tgt: tgt,
                    model.tgt_len: tgt_len,
                    model.training: True
                })

            steps += 1
            total_loss += loss
            total_accu += accu
            batch_iter.set_description('loss: {:>.4f} accuracy: {:>.4f}'.format(loss, accu))

            if global_step % args.log_steps == 0 and summary_writer is not None:
                summary_writer.add_summary(summary, global_step)
            if global_step % args.save_steps == 0:
                # evaluate saved models after pre-train epochs
                if i < args.pre_train_epochs:
                    saver.save(sess, config.model_file, global_step=global_step)
                else:
                    predicted_ids, valid_loss, valid_accu = run_evaluate(sess, model, valid_data)
                    logger.info('valid loss: {:>.4f}, valid accuracy: {:>.4f}'.format(valid_loss, valid_accu))

                    save_outputs(predicted_ids, config.id_2_word, config.valid_data, config.valid_outputs)
                    valid_results = evaluator.evaluate(config.valid_data, config.valid_outputs, config.to_lower)

                    # early stop
                    if valid_results['BLEU 4'] >= best_valid_result:
                        flag = 0
                        best_valid_result = valid_results['BLEU 4']
                        logger.info('saving model-{}'.format(global_step))
                        saver.save(sess, config.model_file, global_step=global_step)
                        save_json(valid_results, config.valid_results)
                    elif flag < args.early_stop:
                        flag += 1
                    elif args.early_stop:
                        return valid_log_history

                    for key, value in valid_results.items():
                        valid_log_history[key].append(value)
                    valid_log_history['loss'].append(valid_loss)
                    valid_log_history['accuracy'].append(valid_accu)
                    valid_log_history['global_step'].append(int(global_step))

        logger.info('train loss: {:>.4f}, train accuracy: {:>.4f}'.format(
            total_loss / steps, total_accu / steps))
        saver.save(sess, config.model_file, global_step=global_step)

    return valid_log_history
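# make_batch_iter and make_batch_data are project utilities not shown in this excerpt. A minimal
# sketch of make_batch_data, assuming each batch element is a (topic_ids, triple_ids, src_ids,
# tgt_ids) tuple (as produced by zip(*train_data) above) and that every sequence is padded with
# config.pad_id to the longest sequence of its kind in the batch.
def make_batch_data(batch):
    def pad(seqs):
        lens = [len(seq) for seq in seqs]
        max_len = max(lens)
        padded = [seq + [config.pad_id] * (max_len - len(seq)) for seq in seqs]
        return padded, lens

    topics, triples, srcs, tgts = zip(*batch)
    topic, topic_len = pad(topics)
    triple, triple_len = pad(triples)
    src, src_len = pad(srcs)
    tgt, tgt_len = pad(tgts)
    return topic, topic_len, triple, triple_len, src, src_len, tgt, tgt_len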
if __name__ == '__main__':
    main()
    logger.info(log_title('done'))
def main():
    os.makedirs(config.temp_dir, exist_ok=True)
    os.makedirs(config.result_dir, exist_ok=True)
    os.makedirs(config.train_log_dir, exist_ok=True)

    logger.setLevel(logging.INFO)
    init_logger(logging.INFO)

    logger.info('loading dict...')
    config.src_2_id, config.id_2_src = read_json_dict(config.src_vocab_dict)
    config.src_vocab_size = min(config.src_vocab_size, len(config.src_2_id))
    config.tgt_2_id, config.id_2_tgt = read_json_dict(config.tgt_vocab_dict)
    config.tgt_vocab_size = min(config.tgt_vocab_size, len(config.tgt_2_id))

    data_reader = DataReader(config)
    evaluator = Evaluator('tgt')

    logger.info('building model...')
    model = get_model(config)
    saver = tf.train.Saver(max_to_keep=10)

    if args.do_train:
        logger.info('loading data...')
        train_data = data_reader.load_train_data()
        valid_data = data_reader.load_valid_data()

        logger.info(log_title('Trainable Variables'))
        for v in tf.trainable_variables():
            logger.info(v)

        logger.info(log_title('Gradients'))
        for g in model.gradients:
            logger.info(g)

        with tf.Session(config=sess_config) as sess:
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(
                    os.path.join(config.result_dir, config.current_model))
            if model_file is not None:
                logger.info('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)
            else:
                logger.info('initializing from scratch...')
                tf.global_variables_initializer().run()

            train_writer = tf.summary.FileWriter(config.train_log_dir, sess.graph)
            valid_log_history = run_train(sess, model, train_data, valid_data,
                                          saver, evaluator, train_writer)
            save_json(valid_log_history,
                      os.path.join(config.result_dir, config.current_model,
                                   'valid_log_history.json'))

    if args.do_eval:
        logger.info('loading data...')
        valid_data = data_reader.load_valid_data()

        with tf.Session(config=sess_config) as sess:
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(
                    os.path.join(config.result_dir, config.current_model))
            if model_file is not None:
                logger.info('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)
                predicted_ids, valid_loss, valid_accu = run_evaluate(sess, model, valid_data)
                logger.info('average valid loss: {:>.4f}, average valid accuracy: {:>.4f}'.format(
                    valid_loss, valid_accu))

                logger.info(log_title('Saving Result'))
                save_outputs(predicted_ids, config.id_2_tgt, config.valid_data, config.valid_outputs)
                results = evaluator.evaluate(config.valid_data, config.valid_outputs, config.to_lower)
                save_json(results, config.valid_results)
            else:
                logger.info('model not found!')

    if args.do_test:
        logger.info('loading data...')
        test_data = data_reader.load_test_data()

        with tf.Session(config=sess_config) as sess:
            model_file = args.model_file
            if model_file is None:
                model_file = tf.train.latest_checkpoint(
                    os.path.join(config.result_dir, config.current_model))
            if model_file is not None:
                logger.info('loading model from {}...'.format(model_file))
                saver.restore(sess, model_file)
                predicted_ids = run_test(sess, model, test_data)

                logger.info(log_title('Saving Result'))
                save_outputs(predicted_ids, config.id_2_tgt, config.test_data, config.test_outputs)
                results = evaluator.evaluate(config.test_data, config.test_outputs, config.to_lower)
                save_json(results, config.test_results)
            else:
                logger.info('model not found!')
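# save_outputs is another project utility not shown in this excerpt. A minimal sketch, assuming
# it decodes each predicted id sequence back to tokens (skipping sos, truncating at the first
# eos) and writes one hypothesis per line to output_file. It also assumes integer keys in the
# id-to-word mapping and uses the literal '<unk>' as a placeholder fallback; the input_file
# argument is presumably used to keep outputs aligned with the source examples and is ignored
# in this simplified version.
def save_outputs(predicted_ids, id_2_word, input_file, output_file):
    with open(output_file, 'w', encoding='utf-8') as fout:
        for ids in predicted_ids:
            tokens = []
            for idx in ids:
                word = id_2_word.get(idx, '<unk>')
                if word == config.eos:
                    break
                if word != config.sos:
                    tokens.append(word)
            fout.write(' '.join(tokens) + '\n')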