예제 #1
0
def train(params):
    """Train the model configured by `params` and return the best saved score.

    The function returns ``0.`` immediately when the recorder already signals
    a stop condition (``estop`` set, ``epoch`` beyond ``params.epoches`` or
    ``step`` beyond ``params.max_training_steps``).

    Otherwise it:
      1. builds a ``Dataset`` from ``get_task(params, True)``;
      2. creates one set of input placeholders per entry of ``params.gpus``
         (at least one), an Adam optimizer, the training graph and the
         inference graph;
      3. initializes variables, optionally loads BERT weights, and restores
         parameters through the saver;
      4. runs the epoch loop, where batches are buffered until there is one
         per GPU, then either ``collect_op`` or ``train_op`` is run depending
         on the position inside the ``params.update_cycle`` cycle;
      5. logs every ``params.disp_freq`` global steps, saves every
         ``params.save_freq`` steps, and evaluates (``train="dev"``) every
         ``params.eval_freq`` steps, updating the early-stopping counter;
      6. performs one final evaluation and dumps its predictions.

    Args:
        params: configuration object. It must expose the hyper-parameters
            read below (e.g. ``epoches``, ``max_training_steps``, ``gpus``,
            ``update_cycle``, ``disp_freq``, ``save_freq``, ``eval_freq``,
            ``ema_decay``, ``estop_patience``, ``output_dir``) and a mutable
            ``recorder`` holding the training state (``epoch``, ``step``,
            ``lidx``, ``estop``, ``bad_counter``, score histories).

    Returns:
        ``0.`` if a stop condition was already reached on entry, otherwise
        ``train_saver.best_score``.
    """
    # status measure
    if params.recorder.estop or \
            params.recorder.epoch > params.epoches or \
            params.recorder.step > params.max_training_steps:
        tf.logging.info(
            "Stop condition reached, you have finished training your model.")
        return 0.

    # loading dataset
    tf.logging.info("Begin Loading Training and Dev Dataset")
    start_time = time.time()

    # BERT vocabulary/tokenizer stay None unless BERT is enabled
    bert_vocab, tokenizer = None, None

    if params.enable_bert:
        bert_vocab = params.bert.vocab
        tokenizer = bert.load_tokenizer(params)

    dataset = Dataset(
        get_task(params, True),
        params.max_len,
        params.max_w_len,
        params.max_p_num,
        params.word_vocab,
        bert_vocab,
        tokenizer,
        enable_hierarchy=params.enable_hierarchy,
        char_vocab=params.char_vocab if params.use_char else None,
        enable_char=params.use_char,
        batch_or_token=params.batch_or_token)

    tf.logging.info(
        "End Loading dataset, within {} seconds".format(time.time() -
                                                        start_time))

    # Build Graph
    with tf.Graph().as_default():
        # scalar learning rate, fed at every session run
        lr = tf.placeholder(tf.float32, [], "learn_rate")

        # one dict of placeholders per GPU (a single one when no GPU is listed)
        features = []
        for fidx in range(max(len(params.gpus), 1)):
            feature = {
                "t": tf.placeholder(tf.int32, [None, None],
                                    "t_{}".format(fidx)),
                "l": tf.placeholder(tf.int32, [None], "l_{}".format(fidx)),
            }
            # optional character-level input
            if params.use_char:
                feature["c"] = tf.placeholder(tf.int32, [None, None, None],
                                              "c_{}".format(fidx))

            # optional BERT inputs (fed from 'subword_ids' / 'subword_back')
            if params.enable_bert:
                feature["s"] = tf.placeholder(tf.int32, [None, None],
                                              "s_{}".format(fidx))
                feature["sb"] = tf.placeholder(tf.int32, [None, None],
                                               "sb_{}".format(fidx))

            features.append(feature)

        # session info
        sess = util.get_session(params.gpus)

        tf.logging.info("Begining Building Training Graph")
        start_time = time.time()

        # create global step
        global_step = tf.train.get_or_create_global_step()

        # set up optimizer
        optimizer = tf.train.AdamOptimizer(lr,
                                           beta1=params.beta1,
                                           beta2=params.beta2,
                                           epsilon=params.epsilon)

        # set up training graph
        loss, gradients = tower_train_graph(features, optimizer, model, params)

        # apply pseudo cyclic parallel operation
        # `vle` maps names to fetched values ("loss", "gradient_norm"); `ops`
        # maps names to the ops used below ("zero_op", "collect_op",
        # "train_op" and the EMA ops).
        vle, ops = cycle.create_train_op({"loss": loss}, gradients, optimizer,
                                         global_step, params)

        tf.logging.info(
            "End Building Training Graph, within {} seconds".format(
                time.time() - start_time))

        tf.logging.info("Begin Building Inferring Graph")
        start_time = time.time()

        # set up infer graph
        eval_pred = tower_infer_graph(features, model, params)

        tf.logging.info(
            "End Building Inferring Graph, within {} seconds".format(
                time.time() - start_time))

        # initialize the model
        sess.run(tf.global_variables_initializer())

        # log parameters
        util.variable_printer()

        # create saver
        train_saver = saver.Saver(checkpoints=params.checkpoints,
                                  output_dir=params.output_dir)

        tf.logging.info("Training")
        # position inside the update cycle; 1 marks the start of a new cycle
        cycle_counter = 1
        # values accumulated between two display steps
        cum_loss, cum_gnorm = [], []

        # restore parameters
        tf.logging.info("Trying restore existing parameters")
        if params.enable_bert:
            bert.load_model(sess, params.bert_dir)
        train_saver.restore(sess)

        # setup learning rate
        adapt_lr = lrs.get_lr(params)

        start_time = time.time()
        # resume from the epoch stored in the recorder
        start_epoch = params.recorder.epoch
        # batches waiting to be fed, one per GPU
        data_on_gpu = []
        for epoch in range(start_epoch, params.epoches + 1):

            params.recorder.epoch = epoch

            tf.logging.info("Training the model for epoch {}".format(epoch))
            # batch size is counted in samples or in tokens
            size = params.batch_size if params.batch_or_token == 'batch' \
                else params.token_size
            train_batcher = dataset.batcher(size,
                                            buffer_size=params.buffer_size,
                                            shuffle=params.shuffle_batch,
                                            train="train")
            train_queue = queuer.EnQueuer(
                train_batcher,
                multiprocessing=params.data_multiprocessing,
                random_seed=params.random_seed)
            train_queue.start(workers=params.nthreads,
                              max_queue_size=params.max_queue_size)

            adapt_lr.before_epoch(eidx=epoch)
            for lidx, data in enumerate(train_queue.get()):

                # when continuing a previous run, skip the batches up to the
                # recorded index, logging roughly five progress messages
                if params.train_continue:
                    if lidx <= params.recorder.lidx:
                        segments = params.recorder.lidx // 5
                        # the `< 5` test short-circuits, so `segments` is
                        # never 0 when used as a modulus
                        if params.recorder.lidx < 5 or lidx % segments == 0:
                            tf.logging.info(
                                "Passing {}-th index according to record".
                                format(lidx))
                        continue
                params.recorder.lidx = lidx

                data_on_gpu.append(data)
                # use multiple gpus, and data samples is not enough
                # make sure the data is fully added
                # The actual batch size: batch_size * num_gpus * update_cycle
                if len(params.gpus) > 0 and len(data_on_gpu) < len(
                        params.gpus):
                    continue

                # first step of a cycle: run "zero_op" and refresh the
                # learning rate
                if cycle_counter == 1:
                    sess.run(ops["zero_op"])

                    # calculate adaptive learning rate
                    adapt_lr.step(params.recorder.step)

                # merge the per-GPU feeds into a single feed dictionary
                feed_dicts = {}
                # NOTE(review): this loop rebinds `data`; afterwards it refers
                # to the last buffered batch, which is the one appended above.
                for fidx, data in enumerate(data_on_gpu):
                    # define feed_dict
                    feed_dict = {
                        features[fidx]["t"]: data['token_ids'],
                        features[fidx]["l"]: data['l_id'],
                        lr: adapt_lr.get_lr(),
                    }
                    if params.use_char:
                        feed_dict[features[fidx]["c"]] = data['char_ids']
                    if params.enable_bert:
                        feed_dict[features[fidx]["s"]] = data['subword_ids']
                        feed_dict[features[fidx]["sb"]] = data['subword_back']

                    feed_dicts.update(feed_dict)

                # reset data points
                data_on_gpu = []

                # intermediate steps of the cycle only run "collect_op"
                if cycle_counter < params.update_cycle:
                    sess.run(ops["collect_op"], feed_dict=feed_dicts)
                # the last step of the cycle runs "train_op"
                if cycle_counter == params.update_cycle:
                    # set to 0 so the increment at the end of the iteration
                    # brings the counter back to 1
                    cycle_counter = 0

                    _, loss, gnorm, gstep = sess.run([
                        ops["train_op"], vle["loss"], vle["gradient_norm"],
                        global_step
                    ],
                                                     feed_dict=feed_dicts)

                    # abort training on a non-finite loss
                    if np.isnan(loss) or np.isinf(loss):
                        tf.logging.error("Nan or Inf raised")
                        params.recorder.estop = True
                        break

                    cum_loss.append(loss)
                    cum_gnorm.append(gnorm)

                    # periodic progress report, averaged since the last one
                    if gstep % params.disp_freq == 0:
                        end_time = time.time()
                        tf.logging.info(
                            "{} Epoch {}, GStep {}~{}, LStep {}~{}, "
                            "Loss {:.3f}, GNorm {:.3f}, Lr {:.5f}, "
                            "Document {}, UD {:.3f} s".format(
                                util.time_str(end_time), epoch,
                                gstep - params.disp_freq + 1, gstep,
                                lidx - params.disp_freq + 1, lidx,
                                np.mean(cum_loss), np.mean(cum_gnorm),
                                adapt_lr.get_lr(), data['token_ids'].shape,
                                end_time - start_time))
                        start_time = time.time()
                        cum_loss, cum_gnorm = [], []

                    # trigger model saver
                    if gstep > 0 and gstep % params.save_freq == 0:
                        train_saver.save(sess, gstep)
                        params.recorder.save_to_json(
                            os.path.join(params.output_dir, "record.json"))

                    # trigger model evaluation
                    if gstep > 0 and gstep % params.eval_freq == 0:

                        # run the EMA backup/assign ops before evaluating;
                        # the restore op is run after the evaluation
                        if params.ema_decay > 0.:
                            sess.run(ops['ema_backup_op'])
                            sess.run(ops['ema_assign_op'])

                        tf.logging.info("Start Evaluating")
                        eval_start_time = time.time()
                        predictions, score = evalu.predict(sess,
                                                           features,
                                                           eval_pred,
                                                           dataset,
                                                           params,
                                                           train="dev")
                        eval_end_time = time.time()
                        tf.logging.info("End Evaluating")

                        tf.logging.info(
                            "{} GStep {}, Score {}, Duration {:.3f} s".format(
                                util.time_str(eval_end_time), gstep, score,
                                eval_end_time - eval_start_time))

                        if params.ema_decay > 0.:
                            sess.run(ops['ema_restore_op'])

                        # save eval translation
                        evalu.dump_predictions(
                            predictions,
                            os.path.join(params.output_dir,
                                         "eval-{}.trans.txt".format(gstep)))

                        # save parameters
                        train_saver.save(sess, gstep, score)

                        # check for early stopping
                        # the counter is reset when the score beats every
                        # previously recorded one, and incremented otherwise
                        valid_scores = [
                            v[1] for v in params.recorder.valid_script_scores
                        ]
                        if len(valid_scores
                               ) == 0 or score > np.max(valid_scores):
                            params.recorder.bad_counter = 0
                        else:
                            params.recorder.bad_counter += 1

                            # NOTE(review): this break leaves the loop before
                            # the current score is appended to the recorder
                            # and before record.json is written.
                            if params.recorder.bad_counter > params.estop_patience:
                                params.recorder.estop = True
                                break

                        params.recorder.history_scores.append(
                            (gstep, float(score)))
                        params.recorder.valid_script_scores.append(
                            (gstep, float(score)))
                        params.recorder.save_to_json(
                            os.path.join(params.output_dir, "record.json"))

                        # handle the learning rate decay in a typical manner
                        adapt_lr.after_eval(float(score))

                    # trigger stopping
                    if gstep >= params.max_training_steps:
                        params.recorder.estop = True
                        break

                    # should be equal to global_step
                    params.recorder.step += 1.0

                cycle_counter += 1

            train_queue.stop()

            # any break above sets `estop`, which also ends the epoch loop
            if params.recorder.estop:
                tf.logging.info("Early Stopped!")
                break

            # reset the resume index to -1 so that the next epoch skips no
            # batch (every `lidx` is >= 0)
            params.recorder.lidx = -1

            adapt_lr.after_epoch(eidx=epoch)

    # Final Evaluation
    # Runs after the graph context manager has exited, reusing the session,
    # ops and placeholders created inside it.
    tf.logging.info("Start Evaluating")
    if params.ema_decay > 0.:
        sess.run(ops['ema_backup_op'])
        sess.run(ops['ema_assign_op'])

    # step label used in the log line and in the prediction file name
    gstep = int(params.recorder.step + 1)
    eval_start_time = time.time()
    predictions, score = evalu.predict(sess,
                                       features,
                                       eval_pred,
                                       dataset,
                                       params,
                                       train="dev")
    eval_end_time = time.time()
    tf.logging.info("End Evaluating")
    tf.logging.info("{} GStep {}, Score {}, Duration {:.3f} s".format(
        util.time_str(eval_end_time), gstep, score,
        eval_end_time - eval_start_time))

    # save eval translation
    evalu.dump_predictions(
        predictions,
        os.path.join(params.output_dir, "eval-{}.trans.txt".format(gstep)))

    if params.ema_decay > 0.:
        sess.run(ops['ema_restore_op'])

    tf.logging.info("Your training is finished :)")

    return train_saver.best_score
예제 #2
0
파일: main.py 프로젝트: hardik140397/zero
def train(params):
    """Train the translation model configured by `params`.

    The function returns ``0.`` immediately when the recorder already signals
    a stop condition (``estop`` set, ``epoch`` beyond ``params.epoches`` or
    ``step`` beyond ``params.max_training_steps``).

    Otherwise it loads the training and dev datasets, builds the training and
    inference graphs, restores existing parameters and runs the epoch loop:

      * batches up to the recorded ``lidx`` are skipped (resuming a run);
      * within each ``params.update_cycle`` cycle, the intermediate steps run
        ``collect_op`` and the last one runs ``train_op``;
      * every ``params.disp_freq`` global steps the averaged loss and
        gradient norm are logged;
      * every ``params.eval_freq`` global steps the dev set is decoded,
        scored, recorded to ``record.json`` and a checkpoint is saved;
      * every ``params.sample_freq`` global steps up to five translations of
        the current batch are logged;
      * the learning rate is halved after every completed epoch.

    Training stops once the global step reaches
    ``params.max_training_steps``. This ends the epoch loop as well as the
    batch loop, so no further update is applied in the remaining epochs.

    Args:
        params: configuration object exposing the hyper-parameters read
            below and a mutable ``recorder`` holding the training state
            (``epoch``, ``step``, ``lidx``, ``estop``, score histories).

    Returns:
        ``0.`` if a stop condition was already reached on entry, otherwise
        ``train_saver.best_score``.
    """
    # status measure
    if params.recorder.estop or \
            params.recorder.epoch > params.epoches or \
            params.recorder.step > params.max_training_steps:
        tf.logging.info(
            "Stop condition reached, you have finished training your model.")
        return 0.

    # loading dataset
    tf.logging.info("Begin Loading Training and Dev Dataset")
    start_time = time.time()
    train_dataset = Dataset(params.src_train_file,
                            params.tgt_train_file,
                            params.src_vocab,
                            params.tgt_vocab,
                            params.max_len,
                            batch_or_token=params.batch_or_token)
    # The dev set is built from the source file on both sides; the reference
    # file (params.tgt_dev_file) is only read by evalu.eval_metric below.
    dev_dataset = Dataset(params.src_dev_file,
                          params.src_dev_file,
                          params.src_vocab,
                          params.src_vocab,
                          1e6,
                          batch_or_token='batch')
    tf.logging.info(
        "End Loading dataset, within {} seconds".format(time.time() -
                                                        start_time))

    # Build Graph
    with tf.Graph().as_default():
        # scalar learning rate, fed at every session run
        lr = tf.placeholder(tf.float32, [], "learn_rate")
        train_features = {
            "source": tf.placeholder(tf.int32, [None, None], "source"),
            "target": tf.placeholder(tf.int32, [None, None], "target"),
        }
        eval_features = {
            "source": tf.placeholder(tf.int32, [None, None], "source"),
        }

        # session info
        sess = util.get_session(params.gpus)

        tf.logging.info("Begining Building Training Graph")
        start_time = time.time()

        # create global step
        global_step = tf.train.get_or_create_global_step()

        # set up optimizer
        optimizer = tf.train.AdamOptimizer(lr,
                                           beta1=params.beta1,
                                           beta2=params.beta2,
                                           epsilon=params.epsilon)

        # set up training graph
        loss, gradients = tower_train_graph(train_features, optimizer, params)

        # apply pseudo cyclic parallel operation
        # `vle` maps names to fetched values ("loss", "gradient_norm"); `ops`
        # maps names to the ops used below ("zero_op", "collect_op",
        # "train_op").
        vle, ops = cycle.create_train_op({"loss": loss}, gradients, optimizer,
                                         global_step, params)

        tf.logging.info(
            "End Building Training Graph, within {} seconds".format(
                time.time() - start_time))

        tf.logging.info("Begin Building Inferring Graph")
        start_time = time.time()

        # set up infer graph
        eval_seqs, eval_scores, eval_mask = tower_infer_graph(
            eval_features, params)

        tf.logging.info(
            "End Building Inferring Graph, within {} seconds".format(
                time.time() - start_time))

        # initialize the model
        sess.run(tf.global_variables_initializer())

        # log parameters
        util.variable_printer()

        # create saver
        train_saver = saver.Saver(checkpoints=params.checkpoints,
                                  output_dir=params.output_dir)

        tf.logging.info("Training")
        lrate = params.lrate
        # position inside the update cycle; 1 marks the start of a new cycle
        cycle_counter = 1
        # values accumulated between two display steps
        cum_loss = []
        cum_gnorm = []

        # restore parameters
        tf.logging.info("Trying restore existing parameters")
        train_saver.restore(sess)

        start_time = time.time()
        # Set once the global step reaches params.max_training_steps, so that
        # the epoch loop terminates too and not only the batch loop.
        reached_max_steps = False
        for epoch in range(1, params.epoches + 1):

            # skip the epochs already completed according to the recorder
            if epoch < params.recorder.epoch:
                tf.logging.info(
                    "Passing {}-th epoch according to record".format(epoch))
                continue
            params.recorder.epoch = epoch

            tf.logging.info("Training the model for epoch {}".format(epoch))
            # batch size is counted in samples or in tokens
            size = params.batch_size if params.batch_or_token == 'batch' \
                else params.token_size
            train_batcher = train_dataset.batcher(
                size,
                buffer_size=params.buffer_size,
                shuffle=params.shuffle_batch)
            train_queue = queuer.EnQueuer(train_batcher)
            train_queue.start(workers=params.nthreads,
                              max_queue_size=params.max_queue_size)

            for lidx, data in enumerate(train_queue.get()):

                # skip the batches up to the recorded index, logging roughly
                # five progress messages
                if lidx <= params.recorder.lidx:
                    segments = params.recorder.lidx // 5
                    # the `< 5` test short-circuits, so `segments` is never 0
                    # when used as a modulus
                    if params.recorder.lidx < 5 or lidx % segments == 0:
                        tf.logging.info(
                            "Passing {}-th index according to record".format(
                                lidx))
                    continue
                params.recorder.lidx = lidx

                # define feed_dict
                feed_dict = {
                    train_features["source"]: data['src'],
                    train_features["target"]: data['tgt'],
                    lr: lrate,
                }

                # first step of a cycle runs "zero_op"
                if cycle_counter == 1:
                    sess.run(ops["zero_op"])
                # intermediate steps of the cycle only run "collect_op"
                if cycle_counter < params.update_cycle:
                    sess.run(ops["collect_op"], feed_dict=feed_dict)
                # the last step of the cycle runs "train_op"
                if cycle_counter == params.update_cycle:
                    # set to 0 so the increment at the end of the iteration
                    # brings the counter back to 1
                    cycle_counter = 0
                    _, loss, gnorm, gstep, glr = sess.run(
                        [
                            ops["train_op"], vle["loss"],
                            vle["gradient_norm"], global_step, lr
                        ],
                        feed_dict=feed_dict)
                    params.recorder.step = gstep

                    cum_loss.append(loss)
                    cum_gnorm.append(gnorm)

                    # periodic progress report, averaged since the last one
                    if gstep % params.disp_freq == 0:
                        end_time = time.time()
                        tf.logging.info(
                            "{} Epoch {}, GStep {}~{}, LStep {}~{}, "
                            "Loss {:.3f}, GNorm {:.3f}, Lr {:.5f}, Duration {:.3f} s"
                            .format(
                                util.time_str(end_time), epoch,
                                gstep - params.disp_freq + 1, gstep, lidx -
                                params.disp_freq * params.update_cycle + 1,
                                lidx, np.mean(cum_loss), np.mean(cum_gnorm),
                                glr, end_time - start_time))
                        start_time = time.time()
                        cum_loss = []
                        cum_gnorm = []

                    # periodic evaluation on the dev set
                    if gstep > 0 and gstep % params.eval_freq == 0:
                        eval_start_time = time.time()
                        tranes, scores, indices = evalu.decoding(
                            sess, eval_features, eval_seqs, eval_scores,
                            eval_mask, dev_dataset, params)
                        bleu = evalu.eval_metric(tranes,
                                                 params.tgt_dev_file,
                                                 indices=indices)
                        eval_end_time = time.time()
                        tf.logging.info(
                            "{} GStep {}, Scores {}, BLEU {}, Duration {:.3f} s"
                            .format(util.time_str(eval_end_time), gstep,
                                    np.mean(scores), bleu,
                                    eval_end_time - eval_start_time))

                        params.recorder.history_scores.append(
                            (gstep, float(np.mean(scores))))
                        params.recorder.valid_script_scores.append(
                            (gstep, float(bleu)))
                        params.recorder.save_to_json(
                            os.path.join(params.output_dir, "record.json"))

                        # save eval translation
                        evalu.dump_tanslation(
                            tranes,
                            os.path.join(params.output_dir,
                                         "eval-{}.trans.txt".format(gstep)),
                            indices=indices)

                        train_saver.save(sess, gstep, bleu)

                    # periodically log a few translations of the current batch
                    if gstep > 0 and gstep % params.sample_freq == 0:
                        decode_seqs, decode_scores, decode_mask = sess.run(
                            [eval_seqs, eval_scores, eval_mask],
                            feed_dict={eval_features["source"]: data['src']})
                        tranes, scores = evalu.decode_hypothesis(
                            decode_seqs,
                            decode_scores,
                            params,
                            mask=decode_mask)
                        for sidx in range(min(5, len(scores))):
                            sample_source = evalu.decode_target_token(
                                data['src'][sidx], params.src_vocab)
                            tf.logging.info("{}-th Source: {}".format(
                                sidx, ' '.join(sample_source)))
                            sample_target = evalu.decode_target_token(
                                data['tgt'][sidx], params.tgt_vocab)
                            tf.logging.info("{}-th Target: {}".format(
                                sidx, ' '.join(sample_target)))
                            sample_trans = tranes[sidx]
                            tf.logging.info("{}-th Translation: {}".format(
                                sidx, ' '.join(sample_trans)))

                    # step budget exhausted: leave the batch loop and flag the
                    # epoch loop to stop as well
                    if gstep >= params.max_training_steps:
                        reached_max_steps = True
                        break

                cycle_counter += 1

            train_queue.stop()

            # Without this check the next epoch would start a new queue and
            # apply further updates beyond params.max_training_steps.
            if reached_max_steps:
                tf.logging.info(
                    "Reached the maximum number of training steps, "
                    "stop training.")
                break

            # reset the resume index to -1 so that the next epoch skips no
            # batch (every `lidx` is >= 0)
            params.recorder.lidx = -1

            # handle the learning rate decay in a typical manner
            lrate = lrate / 2.

    tf.logging.info("Anyway, your training is finished :)")

    return train_saver.best_score
예제 #3
0
def train(params):
    # status measure
    if params.recorder.estop or \
            params.recorder.epoch > params.epoches or \
            params.recorder.step > params.max_training_steps:
        tf.logging.info(
            "Stop condition reached, you have finished training your model.")
        return 0.

    # loading dataset
    tf.logging.info("Begin Loading Training and Dev Dataset")
    start_time = time.time()
    train_dataset = Dataset(params.src_train_file,
                            params.tgt_train_file,
                            params.src_vocab,
                            params.tgt_vocab,
                            params.max_len,
                            batch_or_token=params.batch_or_token,
                            data_leak_ratio=params.data_leak_ratio)
    dev_dataset = Dataset(params.src_dev_file,
                          params.src_dev_file,
                          params.src_vocab,
                          params.src_vocab,
                          params.eval_max_len,
                          batch_or_token='batch',
                          data_leak_ratio=params.data_leak_ratio)
    tf.logging.info(
        "End Loading dataset, within {} seconds".format(time.time() -
                                                        start_time))

    # Build Graph
    with tf.Graph().as_default():
        lr = tf.placeholder(tf.as_dtype(dtype.floatx()), [], "learn_rate")

        # shift automatically sliced multi-gpu process into `zero` manner :)
        features = []
        for fidx in range(max(len(params.gpus), 1)):
            feature = {
                "source": tf.placeholder(tf.int32, [None, None], "source"),
                "target": tf.placeholder(tf.int32, [None, None], "target"),
            }
            features.append(feature)

        # session info
        sess = util.get_session(params.gpus)

        tf.logging.info("Begining Building Training Graph")
        start_time = time.time()

        # create global step
        global_step = tf.train.get_or_create_global_step()

        # set up optimizer
        optimizer = tf.train.AdamOptimizer(lr,
                                           beta1=params.beta1,
                                           beta2=params.beta2,
                                           epsilon=params.epsilon)

        # get graph
        graph = model.get_model(params.model_name)

        # set up training graph
        loss, gradients = tower_train_graph(features, optimizer, graph, params)

        # apply pseudo cyclic parallel operation
        vle, ops = cycle.create_train_op({"loss": loss}, gradients, optimizer,
                                         global_step, params)

        tf.logging.info(
            "End Building Training Graph, within {} seconds".format(
                time.time() - start_time))

        tf.logging.info("Begin Building Inferring Graph")
        start_time = time.time()

        # set up infer graph
        eval_seqs, eval_scores = tower_infer_graph(features, graph, params)

        tf.logging.info(
            "End Building Inferring Graph, within {} seconds".format(
                time.time() - start_time))

        # initialize the model
        sess.run(tf.global_variables_initializer())

        # log parameters
        util.variable_printer()

        # create saver
        train_saver = saver.Saver(
            checkpoints=params.checkpoints,
            output_dir=params.output_dir,
            best_checkpoints=params.best_checkpoints,
        )

        tf.logging.info("Training")
        cycle_counter = 0
        data_on_gpu = []
        cum_tokens = []

        # restore parameters
        tf.logging.info("Trying restore pretrained parameters")
        train_saver.restore(sess, path=params.pretrained_model)

        tf.logging.info("Trying restore existing parameters")
        train_saver.restore(sess)

        # setup learning rate
        params.lrate = params.recorder.lrate
        adapt_lr = lrs.get_lr(params)

        start_time = time.time()
        start_epoch = params.recorder.epoch
        for epoch in range(start_epoch, params.epoches + 1):

            params.recorder.epoch = epoch

            tf.logging.info("Training the model for epoch {}".format(epoch))
            size = params.batch_size if params.batch_or_token == 'batch' \
                else params.token_size

            train_queue = queuer.EnQueuer(
                train_dataset.batcher(size,
                                      buffer_size=params.buffer_size,
                                      shuffle=params.shuffle_batch,
                                      train=True),
                lambda x: x,
                worker_processes_num=params.process_num,
                input_queue_size=params.input_queue_size,
                output_queue_size=params.output_queue_size,
            )

            adapt_lr.before_epoch(eidx=epoch)

            for lidx, data in enumerate(train_queue):

                if params.train_continue:
                    if lidx <= params.recorder.lidx:
                        segments = params.recorder.lidx // 5
                        if params.recorder.lidx < 5 or lidx % segments == 0:
                            tf.logging.info(
                                "{} Passing {}-th index according to record".
                                format(util.time_str(time.time()), lidx))

                        continue

                params.recorder.lidx = lidx

                data_on_gpu.append(data)
                # use multiple gpus, and data samples is not enough
                # make sure the data is fully added
                # The actual batch size: batch_size * num_gpus * update_cycle
                if len(params.gpus) > 0 and len(data_on_gpu) < len(
                        params.gpus):
                    continue

                # increase the counter by 1
                cycle_counter += 1

                if cycle_counter == 1:
                    # calculate adaptive learning rate
                    adapt_lr.step(params.recorder.step)

                    # clear internal states
                    sess.run(ops["zero_op"])

                # data feeding to gpu placeholders
                feed_dicts = {}
                for fidx, shard_data in enumerate(data_on_gpu):
                    # define feed_dict
                    feed_dict = {
                        features[fidx]["source"]: shard_data["src"],
                        features[fidx]["target"]: shard_data["tgt"],
                        lr: adapt_lr.get_lr(),
                    }
                    feed_dicts.update(feed_dict)

                    # collect target tokens
                    cum_tokens.append(np.sum(shard_data['tgt'] > 0))

                # reset data points on gpus
                data_on_gpu = []

                # internal accumulative gradient collection
                if cycle_counter < params.update_cycle:
                    sess.run(ops["collect_op"], feed_dict=feed_dicts)

                # at the final step, update model parameters
                if cycle_counter == params.update_cycle:
                    cycle_counter = 0

                    # directly update parameters, usually this works well
                    if not params.safe_nan:
                        _, loss, gnorm, pnorm, gstep = sess.run(
                            [
                                ops["train_op"], vle["loss"],
                                vle["gradient_norm"], vle["parameter_norm"],
                                global_step
                            ],
                            feed_dict=feed_dicts)

                        if np.isnan(loss) or np.isinf(loss) or np.isnan(
                                gnorm) or np.isinf(gnorm):
                            tf.logging.error(
                                "Nan or Inf raised! Loss {} GNorm {}.".format(
                                    loss, gnorm))
                            params.recorder.estop = True
                            break
                    else:
                        # Notice, applying safe nan can help train the big model, but sacrifice speed
                        loss, gnorm, pnorm, gstep = sess.run(
                            [
                                vle["loss"], vle["gradient_norm"],
                                vle["parameter_norm"], global_step
                            ],
                            feed_dict=feed_dicts)

                        if np.isnan(loss) or np.isinf(loss) or np.isnan(gnorm) or np.isinf(gnorm) \
                                or gnorm > params.gnorm_upper_bound:
                            tf.logging.error(
                                "Nan or Inf raised, GStep {} is passed! Loss {} GNorm {}."
                                .format(gstep, loss, gnorm))
                            continue

                        sess.run(ops["train_op"], feed_dict=feed_dicts)

                    if gstep % params.disp_freq == 0:
                        end_time = time.time()
                        tf.logging.info(
                            "{} Epoch {}, GStep {}~{}, LStep {}~{}, "
                            "Loss {:.3f}, GNorm {:.3f}, PNorm {:.3f}, Lr {:.5f}, "
                            "Src {}, Tgt {}, Tokens {}, UD {:.3f} s".format(
                                util.time_str(end_time), epoch,
                                gstep - params.disp_freq + 1, gstep,
                                lidx - params.disp_freq + 1, lidx, loss, gnorm,
                                pnorm, adapt_lr.get_lr(),
                                data['src'].shape, data['tgt'].shape,
                                np.sum(cum_tokens), end_time - start_time))
                        start_time = time.time()
                        cum_tokens = []

                    # trigger model saver
                    if gstep > 0 and gstep % params.save_freq == 0:
                        train_saver.save(sess, gstep)
                        params.recorder.save_to_json(
                            os.path.join(params.output_dir, "record.json"))

                    # trigger model evaluation
                    if gstep > 0 and gstep % params.eval_freq == 0:
                        if params.ema_decay > 0.:
                            sess.run(ops['ema_backup_op'])
                            sess.run(ops['ema_assign_op'])

                        tf.logging.info("Start Evaluating")
                        eval_start_time = time.time()
                        tranes, scores, indices = evalu.decoding(
                            sess, features, eval_seqs, eval_scores,
                            dev_dataset, params)
                        bleu = evalu.eval_metric(tranes,
                                                 params.tgt_dev_file,
                                                 indices=indices)
                        eval_end_time = time.time()
                        tf.logging.info("End Evaluating")

                        if params.ema_decay > 0.:
                            sess.run(ops['ema_restore_op'])

                        tf.logging.info(
                            "{} GStep {}, Scores {}, BLEU {}, Duration {:.3f} s"
                            .format(util.time_str(eval_end_time), gstep,
                                    np.mean(scores), bleu,
                                    eval_end_time - eval_start_time))

                        # save eval translation
                        evalu.dump_tanslation(
                            tranes,
                            os.path.join(params.output_dir,
                                         "eval-{}.trans.txt".format(gstep)),
                            indices=indices)

                        # save parameters
                        train_saver.save(sess, gstep, bleu)

                        # check for early stopping
                        valid_scores = [
                            v[1] for v in params.recorder.valid_script_scores
                        ]
                        if len(valid_scores
                               ) == 0 or bleu > np.max(valid_scores):
                            params.recorder.bad_counter = 0
                        else:
                            params.recorder.bad_counter += 1

                            if params.recorder.bad_counter > params.estop_patience:
                                params.recorder.estop = True
                                break

                        params.recorder.history_scores.append(
                            (int(gstep), float(np.mean(scores))))
                        params.recorder.valid_script_scores.append(
                            (int(gstep), float(bleu)))
                        params.recorder.save_to_json(
                            os.path.join(params.output_dir, "record.json"))

                        # handle the learning rate decay in a typical manner
                        adapt_lr.after_eval(float(bleu))

                    # trigger temporary sampling
                    if gstep > 0 and gstep % params.sample_freq == 0:
                        tf.logging.info("Start Sampling")
                        decode_seqs, decode_scores = sess.run(
                            [eval_seqs[:1], eval_scores[:1]],
                            feed_dict={features[0]["source"]: data["src"][:5]})
                        tranes, scores = evalu.decode_hypothesis(
                            decode_seqs, decode_scores, params)

                        for sidx in range(min(5, len(scores))):
                            sample_source = evalu.decode_target_token(
                                data['src'][sidx], params.src_vocab)
                            tf.logging.info("{}-th Source: {}".format(
                                sidx, ' '.join(sample_source)))
                            sample_target = evalu.decode_target_token(
                                data['tgt'][sidx], params.tgt_vocab)
                            tf.logging.info("{}-th Target: {}".format(
                                sidx, ' '.join(sample_target)))
                            sample_trans = tranes[sidx]
                            tf.logging.info("{}-th Translation: {}".format(
                                sidx, ' '.join(sample_trans)))

                        tf.logging.info("End Sampling")

                    # trigger stopping
                    if gstep >= params.max_training_steps:
                        # stop running by setting EStop signal
                        params.recorder.estop = True
                        break

                    # should be equal to global_step
                    params.recorder.step = int(gstep)

            if params.recorder.estop:
                tf.logging.info("Early Stopped!")
                break

            # reset to 0
            params.recorder.lidx = -1

            adapt_lr.after_epoch(eidx=epoch)

    # Final Evaluation
    tf.logging.info("Start Final Evaluating")
    if params.ema_decay > 0.:
        sess.run(ops['ema_backup_op'])
        sess.run(ops['ema_assign_op'])

    gstep = int(params.recorder.step + 1)
    eval_start_time = time.time()
    tranes, scores, indices = evalu.decoding(sess, features, eval_seqs,
                                             eval_scores, dev_dataset, params)
    bleu = evalu.eval_metric(tranes, params.tgt_dev_file, indices=indices)
    eval_end_time = time.time()
    tf.logging.info("End Evaluating")

    if params.ema_decay > 0.:
        sess.run(ops['ema_restore_op'])

    tf.logging.info(
        "{} GStep {}, Scores {}, BLEU {}, Duration {:.3f} s".format(
            util.time_str(eval_end_time), gstep, np.mean(scores), bleu,
            eval_end_time - eval_start_time))

    # save eval translation
    evalu.dump_tanslation(tranes,
                          os.path.join(params.output_dir,
                                       "eval-{}.trans.txt".format(gstep)),
                          indices=indices)

    tf.logging.info("Your training is finished :)")

    return train_saver.best_score