Example #1
import logging
import time

import tensorflow as tf

# DataReader, Evaluator, expand_feed_dict and available_variables are
# project-specific helpers assumed to be provided by the surrounding codebase.


def train(config):
    """Train a model with a config file."""
    logger = logging.getLogger('')
    data_reader = DataReader(config=config)
    # eval() resolves the model class named by the config string.
    model = eval(config.model)(config=config, num_gpus=config.train.num_gpus)
    model.build_train_model(test=config.train.eval_on_dev)

    sess_config = tf.ConfigProto()
    # Grow GPU memory on demand and let TensorFlow fall back to another
    # device when an op has no kernel on the requested one.
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    summary_writer = tf.summary.FileWriter(config.model_dir, graph=model.graph)

    with tf.Session(config=sess_config, graph=model.graph) as sess:
        # Initialize all variables.
        sess.run(tf.global_variables_initializer())
        # Reload variables from disk if a checkpoint exists.
        if tf.train.latest_checkpoint(config.model_dir):
            available_vars = available_variables(config.model_dir)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(sess,
                              tf.train.latest_checkpoint(config.model_dir))
                for v in available_vars:
                    logger.info('Reloaded {} from disk.'.format(v.name))
            else:
                logger.info('Nothing to reload from disk.')
        else:
            logger.info('Nothing to reload from disk.')

        evaluator = Evaluator()
        evaluator.init_from_existed(model, sess, data_reader)

        # dev_bleu and toleration are shared with the nested maybe_save_model().
        global dev_bleu, toleration
        dev_bleu = evaluator.evaluate(
            **config.dev) if config.train.eval_on_dev else 0
        toleration = config.train.toleration

        def train_one_step(batch):
            feat_batch, target_batch, batch_size = batch
            feed_dict = expand_feed_dict({
                model.src_pls: feat_batch,
                model.dst_pls: target_batch
            })
            step, lr, loss, _ = sess.run(
                [model.global_step, model.learning_rate, model.loss,
                 model.train_op],
                feed_dict=feed_dict)
            if step % config.train.summary_freq == 0:
                summary = sess.run(model.summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary, global_step=step)
            return step, lr, loss

        def maybe_save_model():
            global dev_bleu, toleration
            new_dev_bleu = evaluator.evaluate(
                **config.dev) if config.train.eval_on_dev else dev_bleu + 1
            if new_dev_bleu >= dev_bleu:
                mp = config.model_dir + '/model_step_{}'.format(step)
                model.saver.save(sess, mp)
                logger.info('Saved model to %s.' % mp)
                toleration = config.train.toleration
                dev_bleu = new_dev_bleu
            else:
                toleration -= 1

        step = 0
        for epoch in range(1, config.train.num_epochs + 1):
            for batch in data_reader.get_training_batches_with_buckets():

                # Train normal instances.
                start_time = time.time()
                step, lr, loss = train_one_step(batch)
                logger.info(
                    'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tloss: {3:.4f}\ttime: {4:.4f}\tbatch_size: {5}'
                    .format(epoch, step, lr, loss,
                            time.time() - start_time, batch[2]))
                # Save model
                if config.train.save_freq > 0 and step % config.train.save_freq == 0:
                    maybe_save_model()

                if config.train.num_steps and step >= config.train.num_steps:
                    break

            # Save the model once per epoch when config.train.save_freq is zero or negative.
            if config.train.save_freq <= 0:
                maybe_save_model()

            # Early stop
            if toleration <= 0:
                break
        logger.info("Finish training.")
Example #2
def train(config, num_epoch, last_pretrain_model_dir, pretrain_model_dir,
          model_dir, block_idx_enc, block_idx_dec):
    """Pretrain a model with a config file, growing it block by block."""
    # Assumes the same imports and project helpers as Example #1, plus the
    # *_without_global_step checkpoint helpers used below.
    logger = logging.getLogger('')
    config.num_blocks_enc = block_idx_enc
    config.num_blocks_dec = block_idx_dec
    # if block_idx >= 2:
    #     config.train.var_filter = 'encoder/block_' + str(block_idx - 1) + '|' + 'decoder/block_' + str(
    #         block_idx - 1) + '|' + 'encoder/src_embedding' + '|' + 'decoder/dst_embedding'
    # if block_idx >= 2:
    #     config.train.var_filter = 'encoder/block_' + str(block_idx - 1) + '|' + 'decoder/block_' + str(
    #         block_idx - 1)
    logger.info("config.num_blocks_enc=" + str(config.num_blocks_enc) +
                ",config.num_blocks_dec=" + str(config.num_blocks_dec) +
                ',config.train.var_filter=' + str(config.train.var_filter))
    """Train a model with a config file."""
    data_reader = DataReader(config=config)
    model = eval(config.model)(config=config, num_gpus=config.train.num_gpus)
    model.build_train_model(test=config.train.eval_on_dev)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    summary_writer = tf.summary.FileWriter(pretrain_model_dir,
                                           graph=model.graph)

    with tf.Session(config=sess_config, graph=model.graph) as sess:
        # Initialize all variables.
        sess.run(tf.global_variables_initializer())
        # Reload variables from disk if a checkpoint exists.
        if tf.train.latest_checkpoint(last_pretrain_model_dir):
            available_vars = available_variables_without_global_step(
                last_pretrain_model_dir)
            # available_vars = available_variables(last_pretrain_model_dir)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(
                    sess, tf.train.latest_checkpoint(last_pretrain_model_dir))
                for v in available_vars:
                    logger.info('Reloaded {} from disk.'.format(v.name))
            else:
                logger.info('Nothing to reload from disk.')
        else:
            logger.info('Nothing to reload from disk.')

        evaluator = Evaluator()
        evaluator.init_from_existed(model, sess, data_reader)

        global dev_bleu, toleration
        dev_bleu = evaluator.evaluate(
            **config.dev) if config.train.eval_on_dev else 0
        toleration = config.train.toleration

        def train_one_step(batch):
            feat_batch, target_batch = batch
            feed_dict = expand_feed_dict({
                model.src_pls: feat_batch,
                model.dst_pls: target_batch
            })
            step, lr, loss, _ = sess.run(
                [model.global_step, model.learning_rate, model.loss,
                 model.train_op],
                feed_dict=feed_dict)
            if step % config.train.summary_freq == 0:
                logger.info('Writing pretraining summaries...')
                summary = sess.run(model.summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary, global_step=step)
                summary_writer.flush()
            return step, lr, loss

        def maybe_save_model(model_dir, is_save_global_step=True):
            global dev_bleu, toleration
            new_dev_bleu = evaluator.evaluate(
                **config.dev) if config.train.eval_on_dev else dev_bleu + 1
            if new_dev_bleu >= dev_bleu:
                mp = model_dir + '/pretrain_model_step_{}'.format(step)

                if is_save_global_step:
                    model.saver.save(sess, mp)
                else:
                    # Omit the global step so a later run restores the
                    # weights but restarts the step counter from zero.
                    variables_without_global_step = global_variables_without_global_step()
                    saver = tf.train.Saver(
                        var_list=variables_without_global_step, max_to_keep=10)
                    saver.save(sess, mp)

                logger.info('Saved model to %s.' % mp)
                toleration = config.train.toleration
                dev_bleu = new_dev_bleu
            else:
                toleration -= 1

        step = 0
        for epoch in range(1, num_epoch + 1):
            for batch in data_reader.get_training_batches_with_buckets():
                # Train normal instances.
                start_time = time.time()
                step, lr, loss = train_one_step(batch)
                logger.info(
                    'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tloss: {3:.4f}\ttime: {4:.4f}'
                    .format(epoch, step, lr, loss,
                            time.time() - start_time))

                if config.train.num_steps and step >= config.train.num_steps:
                    break

            # Early stop
            if toleration <= 0:
                break

        maybe_save_model(pretrain_model_dir)
        if model_dir:
            maybe_save_model(model_dir, False)
        logger.info("Finish pretrain block_idx_enc=" + str(block_idx_enc) +
                    ',block_idx_dec=' + str(block_idx_dec))
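
Example #2's signature implies an outer driver that grows the network one block at a time, pointing each stage at the previous stage's checkpoint directory. No such driver appears above; a hypothetical sketch, in which the directory layout, the stage loop, and the symmetric encoder/decoder growth are all assumptions, could look like this:

def pretrain_progressively(config, num_blocks, num_epoch, base_dir, model_dir):
    """Sketch: grow encoder and decoder one block at a time."""
    last_dir = ''  # no checkpoint yet; stage 1 trains from scratch
    for block_idx in range(1, num_blocks + 1):
        stage_dir = '{}/pretrain_block_{}'.format(base_dir, block_idx)
        # Only the final stage also exports a checkpoint without the global
        # step into model_dir, ready for the main training run.
        final_dir = model_dir if block_idx == num_blocks else ''
        train(config, num_epoch, last_dir, stage_dir, final_dir,
              block_idx_enc=block_idx, block_idx_dec=block_idx)
        last_dir = stage_dir

Note that each stage restores the previous stage's weights but not its global step (available_variables_without_global_step excludes it), so the learning-rate schedule restarts per stage, and the final save into model_dir omits the global step entirely.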