def train():
    mode = tf.estimator.ModeKeys.TRAIN
    os.makedirs(FLAGS.logdir, exist_ok=True)
    logging.info('Begin training...')

    # prepare training data: tab-separated source/target pairs, one per line
    logging.info('Load training data...')
    train_pairs = []
    train_files = FLAGS.train.split(';')
    for file in train_files:
        with open(file, 'r', encoding='utf-8') as f:
            for line in f:
                s, t = line.strip().split('\t')
                s = s.split()
                if len(s) > FLAGS.maxlen:
                    continue
                t = t.split()
                if len(t) > FLAGS.maxlen:
                    continue
                train_pairs.append((s, t))
    train_batch = data_help.MyBatch(train_pairs, FLAGS.batch_size, token2idx)
    num_train_batches = train_batch.num_batch
    num_train_samples = train_batch.num_samples
    logging.info('Train batches: {} Train samples: {}'.format(
        num_train_batches, num_train_samples))

    # model
    logging.info('Build model graph...')
    x = tf.placeholder(dtype=tf.int32, shape=(None, None), name='x')
    x_seqlens = tf.placeholder(dtype=tf.int32, shape=(None,), name='x_seqlens')
    y = tf.placeholder(dtype=tf.int32, shape=(None, None), name='y')
    decoder_inputs = tf.placeholder(dtype=tf.int32, shape=(None, None),
                                    name='decoder_inputs')
    y_seqlens = tf.placeholder(dtype=tf.int32, shape=(None,), name='y_seqlens')

    with tf.variable_scope('encoder'):
        enc_emb = tf.nn.embedding_lookup(embed, x)
        outputs, state, _ = encoder.encode(enc_emb, x_seqlens, mode=mode)

    with tf.variable_scope('decoder'):
        dec_emb = tf.nn.embedding_lookup(embed, decoder_inputs)
        logits, _, _ = decoder.decode(dec_emb,
                                      y_seqlens,
                                      vocab_size=FLAGS.vocab_size,
                                      memory=outputs,
                                      mode=mode,
                                      memory_sequence_length=x_seqlens,
                                      initial_state=state)

    loss, normalizer, _ = opennmt.utils.losses.cross_entropy_sequence_loss(
        logits, y, y_seqlens, average_in_time=True, mode=mode)
    loss /= normalizer

    global_step = tf.train.get_or_create_global_step()
    lr = tf.Variable(FLAGS.lr, trainable=False)
    learning_rate_decay_op = lr.assign(lr * FLAGS.learning_rate_decay_factor)
    optimizer = tf.train.AdamOptimizer(lr)
    train_op = optimizer.minimize(loss, global_step=global_step)

    # initialize model
    saver = tf.train.Saver(max_to_keep=FLAGS.num_epochs)
    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    with tf.Session(config=gpu_config) as sess:
        ckpt = tf.train.latest_checkpoint(FLAGS.logdir)
        if ckpt is None:
            logging.info('Initializing from scratch')
            sess.run(tf.global_variables_initializer())
        else:
            logging.info('Restore from {}'.format(ckpt))
            saver.restore(sess, ckpt)

        # train
        old_loss = []
        total_steps = FLAGS.num_epochs * num_train_batches
        _gs = sess.run(global_step)
        for _ in range(_gs, total_steps + 1):
            # get feed data
            batch = train_batch.get_next()
            _x, _x_seqlens, _y, _decoder_inputs, _y_seqlens = \
                train_batch.process_batch(batch)
            # fetch lr and loss in the same run as train_op instead of paying
            # for a second forward pass per step
            _, _gs, _lr, _loss = sess.run(
                [train_op, global_step, lr, loss],
                feed_dict={
                    x: _x,
                    x_seqlens: _x_seqlens,
                    y: _y,
                    y_seqlens: _y_seqlens,
                    decoder_inputs: _decoder_inputs
                })
            epoch = math.ceil(_gs / num_train_batches)
            if _gs % FLAGS.print_period == 0:
                logging.info('global step {}, lr {}, loss {}'.format(
                    _gs, _lr, _loss))
            # decay lr when the loss stops improving over the last 5 steps
            if len(old_loss) > 5 and _loss > max(old_loss[-5:]):
                sess.run(learning_rate_decay_op)
            old_loss.append(_loss)
            if _gs and _gs % num_train_batches == 0:
                logging.info('epoch {} is done'.format(epoch))
                ckpt_path = os.path.join(FLAGS.logdir, 'ckpt')
                saver.save(sess, ckpt_path, global_step=_gs)
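
# For reference, a minimal sketch of the batching interface that train() and
# eval() assume from data_help.MyBatch: num_batch / num_samples for
# bookkeeping, get_next() to cycle through the data, and process_batch() to
# turn token pairs into padded id matrices plus lengths. The PAD/SOS/EOS/UNK
# ids, the padding scheme, and the wrap-around behavior below are assumptions
# for illustration, not the project's actual implementation.
class _MyBatchSketch:
    PAD_ID, SOS_ID, EOS_ID, UNK_ID = 0, 1, 2, 3

    def __init__(self, pairs, batch_size, token2idx):
        self.pairs = pairs
        self.batch_size = batch_size
        self.token2idx = token2idx
        self.num_samples = len(pairs)
        # ceil division so a partial final batch is still served
        self.num_batch = (self.num_samples + batch_size - 1) // batch_size
        self._cursor = 0

    def get_next(self):
        # serve the next slice, wrapping around and topping up a short tail
        batch = self.pairs[self._cursor:self._cursor + self.batch_size]
        self._cursor += self.batch_size
        if self._cursor >= self.num_samples:
            self._cursor = 0
            batch = batch + self.pairs[:self.batch_size - len(batch)]
        return batch

    def _to_ids(self, tokens):
        return [self.token2idx.get(t, self.UNK_ID) for t in tokens]

    def process_batch(self, batch):
        # source: ids only; target: y = tokens + EOS, decoder_inputs = SOS + tokens
        xs = [self._to_ids(s) for s, _ in batch]
        ys = [self._to_ids(t) + [self.EOS_ID] for _, t in batch]
        dec = [[self.SOS_ID] + self._to_ids(t) for _, t in batch]
        x_seqlens = [len(s) for s in xs]
        y_seqlens = [len(t) for t in ys]

        def pad(seqs):
            maxlen = max(len(s) for s in seqs)
            return [s + [self.PAD_ID] * (maxlen - len(s)) for s in seqs]

        return pad(xs), x_seqlens, pad(ys), pad(dec), y_seqlens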
def eval():
    os.makedirs(FLAGS.testdir, exist_ok=True)
    mode = tf.estimator.ModeKeys.PREDICT
    logging.info('Begin testing...')

    logging.info('Load test data...')
    test_pairs = []
    with open(FLAGS.test, 'r', encoding='utf-8') as f:
        for line in f:
            s, t = line.strip().split('\t')
            s = s.split()
            t = t.split()
            test_pairs.append((s, t))
    test_batch = data_help.MyBatch(test_pairs, FLAGS.batch_size, token2idx)
    num_test_batches = test_batch.num_batch
    num_test_samples = test_batch.num_samples
    logging.info('Test batches: {} Test samples: {}'.format(
        num_test_batches, num_test_samples))

    # model
    logging.info('Build model graph')
    x = tf.placeholder(dtype=tf.int32, shape=(None, None), name='x')
    x_seqlens = tf.placeholder(dtype=tf.int32, shape=(None,), name='x_seqlens')
    y = tf.placeholder(dtype=tf.int32, shape=(None, None), name='y')
    decoder_inputs = tf.placeholder(dtype=tf.int32, shape=(None, None),
                                    name='decoder_inputs')
    y_seqlens = tf.placeholder(dtype=tf.int32, shape=(None,), name='y_seqlens')

    with tf.variable_scope('encoder'):
        enc_emb = tf.nn.embedding_lookup(embed, x)
        outputs, state, _ = encoder.encode(enc_emb, x_seqlens, mode=mode)

    with tf.variable_scope('decoder'):
        # decode from SOS until EOS with a width-1 beam search
        start_tokens = tf.fill([FLAGS.batch_size], data_help.SOS_ID)
        end_token = data_help.EOS_ID
        target_ids, _, _, _ = decoder.dynamic_decode_and_search(
            embed,
            start_tokens,
            end_token,
            vocab_size=FLAGS.vocab_size,
            beam_width=1,
            memory=outputs,
            initial_state=state,
            memory_sequence_length=x_seqlens)

    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    with tf.Session(config=gpu_config) as sess:
        # load ckpt
        logging.info('Loading checkpoint')
        if FLAGS.ckpt_path != '':
            ckpt_path = os.path.join(FLAGS.logdir, FLAGS.ckpt_path)
        else:
            ckpt_path = tf.train.latest_checkpoint(FLAGS.logdir)
        logging.info('Using checkpoint {}'.format(ckpt_path))
        saver = tf.train.Saver()
        saver.restore(sess, ckpt_path)

        logging.info('Inference...')
        hypotheses = []
        for _ in range(num_test_batches):
            # get feed data; only x and x_seqlens are consumed at predict time
            batch = test_batch.get_next()
            _x, _x_seqlens, _y, _decoder_inputs, _y_seqlens = \
                test_batch.process_batch(batch)
            h = sess.run(target_ids,
                         feed_dict={
                             x: _x,
                             x_seqlens: _x_seqlens,
                             y: _y,
                             y_seqlens: _y_seqlens,
                             decoder_inputs: _decoder_inputs
                         })
            # keep the best hypothesis of each beam
            h = h[:, 0, :]
            hypotheses.extend(h.tolist())
        hypotheses = my_utils.postprocess(hypotheses, idx2token)
        # drop padding added to fill the last batch
        hypotheses = hypotheses[:num_test_samples]

        logging.info('Writing to result...')
        with open(os.path.join(FLAGS.testdir, 'result'), 'w',
                  encoding='utf-8') as fw:
            fw.write('\n'.join(hypotheses))
        logging.info('Inference Done.')

    # calculate bleu
    logging.info('Calculating bleu score')
    golden_file = os.path.join(FLAGS.testdir, 'golden_temp')
    predict_file = os.path.join(FLAGS.testdir, 'result')
    bleu_file = os.path.join(FLAGS.testdir, 'bleu')
    with open(golden_file, 'w', encoding='utf-8') as fw:
        for s, t in test_pairs:
            fw.write(' '.join(t) + '\n')
    os.system('perl multi-bleu.perl {} < {} > {}'.format(
        golden_file, predict_file, bleu_file))
    logging.info('Bleu file is: {}'.format(bleu_file))
    with open(bleu_file) as f:
        logging.info(f.read())
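
# A plausible sketch of the my_utils.postprocess helper eval() calls: map the
# predicted ids back to tokens, cut each hypothesis at the first EOS, and join
# into a space-separated string. The eos_id default and special-token handling
# are assumptions; adjust to match the project's idx2token vocabulary.
def _postprocess_sketch(id_seqs, idx2token, eos_id=2):
    hypotheses = []
    for ids in id_seqs:
        tokens = []
        for i in ids:
            if i == eos_id:  # stop at the first end-of-sentence id
                break
            tokens.append(idx2token[i])
        hypotheses.append(' '.join(tokens))
    return hypotheses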