Example #1
def maybe_save_model(config):
    # `epoch`, `model`, `sess` and `logger` come from the enclosing scope.
    mp = config.model_dir + '/model_epoch_{}'.format(epoch)
    model.saver.save(sess, mp)
    logger.info('Save model in %s.' % mp)

    if config.train.eval_on_dev:
        evaluator = Evaluator(config)
        evaluator.init_from_existed(config, model, sess)
        evaluator.translate(
            config.dev.feat_file_pattern,
            config.dev.output_file + 'decode_result_epoch_{}'.format(epoch))
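
The checkpoint path above is built by plain string concatenation; a small helper one could use instead (a sketch with a hypothetical name, not part of the original code) also tolerates a trailing slash in config.model_dir:

import os

def checkpoint_path(model_dir, epoch):
    # Same result as model_dir + '/model_epoch_{}'.format(epoch),
    # but os.path.join avoids a doubled slash if model_dir ends in '/'.
    return os.path.join(model_dir, 'model_epoch_{}'.format(epoch))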
Example #2
def train(config):
    """Train a model with a config file."""
    logger = logging.getLogger('')
    du = DataUtil(config=config)
    model = Transformer(config=config, devices=config.train.devices)
    model.build_train_model()

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    summary_writer = tf.summary.FileWriter(config.train.logdir,
                                           graph=model.graph)

    with tf.Session(config=sess_config, graph=model.graph) as sess:
        try:
            model.saver.restore(
                sess, tf.train.latest_checkpoint(config.train.logdir))
        except (tf.errors.NotFoundError, ValueError):
            # No usable checkpoint; initialize all variables instead.
            sess.run(tf.global_variables_initializer())
            logger.info('Failed to reload model.')

        evaluator = Evaluator()
        evaluator.init_from_existed(model, sess, du)

        dev_bleu = evaluator.evaluate(**config.dev)
        toleration = config.train.toleration
        for epoch in range(1, config.train.num_epochs + 1):
            for batch in du.get_training_batches_with_buckets():
                start_time = time.time()
                step = sess.run(model.global_step)
                # Summary
                if step % config.train.summary_freq == 0:
                    step, lr, gnorm, loss, acc, summary, _ = sess.run(
                        [model.global_step, model.learning_rate,
                         model.grads_norm, model.loss, model.accuracy,
                         model.summary_op, model.train_op],
                        feed_dict={model.src_pl: batch[0],
                                   model.dst_pl: batch[1]})
                    summary_writer.add_summary(summary, global_step=step)
                else:
                    step, lr, gnorm, loss, acc, _ = sess.run(
                        [model.global_step, model.learning_rate,
                         model.grads_norm, model.loss, model.accuracy,
                         model.train_op],
                        feed_dict={model.src_pl: batch[0],
                                   model.dst_pl: batch[1]})
                logger.info(
                    'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tgnorm: {3:.4f}\tloss: {4:.4f}\tacc: {5:.4f}\ttime: {6:.4f}'
                    .format(epoch, step, lr, gnorm, loss, acc,
                            time.time() - start_time))

                # Save model
                if config.train.save_freq > 0 and step % config.train.save_freq == 0:
                    new_dev_bleu = evaluator.evaluate(**config.dev)
                    if new_dev_bleu >= dev_bleu:
                        mp = config.train.logdir + '/model_epoch_%d_step_%d' % (
                            epoch, step)
                        model.saver.save(sess, mp)
                        logger.info('Save model in %s.' % mp)
                        toleration = config.train.toleration
                        dev_bleu = new_dev_bleu
                    else:
                        toleration -= 1
                        if toleration <= 0:
                            break

            # Save model per epoch if config.train.save_freq is less than or equal to zero
            if config.train.save_freq <= 0:
                new_dev_bleu = evaluator.evaluate(**config.dev)
                if new_dev_bleu >= dev_bleu:
                    mp = config.train.logdir + '/model_epoch_%d' % (epoch)
                    model.saver.save(sess, mp)
                    logger.info('Save model in %s.' % mp)
                    toleration = config.train.toleration
                    dev_bleu = new_dev_bleu
                else:
                    toleration -= 1
                    if toleration <= 0:
                        break

            if toleration <= 0:
                break
        logger.info("Finish training.")
Example #3
def train(config):
    """Train a model with a config file."""
    logger = logging.getLogger('')
    data_reader = DataReader(config=config)
    # Instantiate the model class named in the config (e.g. Transformer).
    model = eval(config.model)(config=config, num_gpus=config.train.num_gpus)
    model.build_train_model(test=config.train.eval_on_dev)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    summary_writer = tf.summary.FileWriter(config.model_dir, graph=model.graph)

    with tf.Session(config=sess_config, graph=model.graph) as sess:
        # Initialize all variables.
        sess.run(tf.global_variables_initializer())
        # Reload variables from disk.
        if tf.train.latest_checkpoint(config.model_dir):
            available_vars = available_variables(config.model_dir)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(sess,
                              tf.train.latest_checkpoint(config.model_dir))
                for v in available_vars:
                    logger.info('Reload {} from disk.'.format(v.name))
            else:
                logger.info('Nothing to reload from disk.')
        else:
            logger.info('Nothing to reload from disk.')

        evaluator = Evaluator()
        evaluator.init_from_existed(model, sess, data_reader)

        global dev_bleu, toleration
        dev_bleu = evaluator.evaluate(
            **config.dev) if config.train.eval_on_dev else 0
        toleration = config.train.toleration

        def train_one_step(batch):
            feat_batch, target_batch, batch_size = batch
            feed_dict = expand_feed_dict({
                model.src_pls: feat_batch,
                model.dst_pls: target_batch
            })
            step, lr, loss, _ = sess.run(
                [model.global_step, model.learning_rate, model.loss,
                 model.train_op],
                feed_dict=feed_dict)
            if step % config.train.summary_freq == 0:
                summary = sess.run(model.summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary, global_step=step)
            return step, lr, loss

        def maybe_save_model():
            global dev_bleu, toleration
            new_dev_bleu = evaluator.evaluate(
                **config.dev) if config.train.eval_on_dev else dev_bleu + 1
            if new_dev_bleu >= dev_bleu:
                mp = config.model_dir + '/model_step_{}'.format(step)
                model.saver.save(sess, mp)
                logger.info('Save model in %s.' % mp)
                toleration = config.train.toleration
                dev_bleu = new_dev_bleu
            else:
                toleration -= 1

        step = 0
        for epoch in range(1, config.train.num_epochs + 1):
            for batch in data_reader.get_training_batches_with_buckets():

                # Train normal instances.
                start_time = time.time()
                step, lr, loss = train_one_step(batch)
                logger.info(
                    'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tloss: {3:.4f}\ttime: {4:.4f}\tbatch_size: {5}'
                    .format(epoch, step, lr, loss,
                            time.time() - start_time, batch[2]))
                # Save model
                if config.train.save_freq > 0 and step % config.train.save_freq == 0:
                    maybe_save_model()

                if config.train.num_steps and step >= config.train.num_steps:
                    break

            # Save model per epoch if config.train.save_freq is less than or equal to zero
            if config.train.save_freq <= 0:
                maybe_save_model()

            # Early stop, or hard stop once config.train.num_steps is reached
            # (the break inside the batch loop only ends the current epoch).
            if toleration <= 0:
                break
            if config.train.num_steps and step >= config.train.num_steps:
                break
        logger.info("Finish training.")
Example #4
def train(config):
    """Train a model with a config file."""
    logger = logging.getLogger('')
    data_reader = DataReader(config=config)
    model = eval(config.model)(config=config, num_gpus=config.train.num_gpus)
    model.build_train_model(test=config.train.eval_on_dev)

    train_op, loss_op = model.get_train_op(name=None)
    global_saver = tf.train.Saver()

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    summary_writer = tf.summary.FileWriter(config.model_dir)

    with tf.Session(config=sess_config) as sess:
        # Initialize all variables.
        sess.run(tf.global_variables_initializer())
        # Reload variables from disk.
        if tf.train.latest_checkpoint(config.model_dir):
            available_vars = available_variables(config.model_dir)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(sess, tf.train.latest_checkpoint(config.model_dir))
                for v in available_vars:
                    logger.info('Reload {} from disk.'.format(v.name))
            else:
                logger.info('Nothing to reload from disk.')
        else:
            logger.info('Nothing to reload from disk.')

        evaluator = Evaluator()
        evaluator.init_from_existed(model, sess, data_reader)

        global dev_bleu, toleration
        dev_bleu = evaluator.evaluate(**config.dev) if config.train.eval_on_dev else 0
        toleration = config.train.toleration

        def train_one_step(batch, loss_op, train_op):
            feed_dict = expand_feed_dict({model.src_pls: batch[0], model.dst_pls: batch[1]})
            step, lr, loss, _ = sess.run(
                [model.global_step, model.learning_rate,
                 loss_op, train_op],
                feed_dict=feed_dict)
            if step % config.train.summary_freq == 0:
                summary = sess.run(model.summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary, global_step=step)
            return step, lr, loss

        def maybe_save_model():
            global dev_bleu, toleration

            def save():
                mp = config.model_dir + '/model_step_{}'.format(step)
                global_saver.save(sess, mp)
                logger.info('Save model in %s.' % mp)

            if config.train.eval_on_dev:
                new_dev_bleu = evaluator.evaluate(**config.dev)

                summary = tf.Summary(value=[tf.Summary.Value(tag="dev_bleu",
                                                             simple_value=new_dev_bleu)])

                summary_writer.add_summary(summary, step)

                if config.train.toleration is None:
                    save()
                else:
                    if new_dev_bleu >= dev_bleu:
                        save()
                        toleration = config.train.toleration
                        dev_bleu = new_dev_bleu
                    else:
                        toleration -= 1
            else:
                save()

        try:
            step = 0
            for epoch in range(1, config.train.num_epochs+1):
                for batch in data_reader.get_training_batches(epoches=1):

                    # Train normal instances.
                    start_time = time.time()
                    step, lr, loss = train_one_step(batch, loss_op, train_op)
                    logger.info(
                        'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tloss: {3:.4f}\ttime: {4:.4f}'.
                        format(epoch, step, lr, loss, time.time() - start_time))
                    # Save model
                    if config.train.save_freq > 0 \
                       and step > 0 \
                       and step % config.train.save_freq == 0:
                        maybe_save_model()

                    if config.train.num_steps is not None and step >= config.train.num_steps:
                        raise BreakLoopException("BreakLoop")

                    if toleration is not None and toleration <= 0:
                        raise BreakLoopException("BreakLoop")

                # Save model per epoch if config.train.save_freq is less than or equal to zero
                if config.train.save_freq <= 0:
                    maybe_save_model()
        except BreakLoopException as e:
            logger.info(e)

        logger.info("Finish training.")
Example #5
def train(config, num_epoch, last_pretrain_model_dir, pretrain_model_dir,
          model_dir, block_idx_enc, block_idx_dec):
    """Train a model with a config file."""
    logger = logging.getLogger('')
    config.num_blocks_enc = block_idx_enc
    config.num_blocks_dec = block_idx_dec
    # if block_idx >= 2:
    #     config.train.var_filter = 'encoder/block_' + str(block_idx - 1) + '|' + 'decoder/block_' + str(
    #         block_idx - 1) + '|' + 'encoder/src_embedding' + '|' + 'decoder/dst_embedding'
    # if block_idx >= 2:
    #     config.train.var_filter = 'encoder/block_' + str(block_idx - 1) + '|' + 'decoder/block_' + str(
    #         block_idx - 1)
    logger.info("config.num_blocks_enc=" + str(config.num_blocks_enc) +
                ",config.num_blocks_dec=" + str(config.num_blocks_dec) +
                ',config.train.var_filter=' + str(config.train.var_filter))
    """Train a model with a config file."""
    data_reader = DataReader(config=config)
    model = eval(config.model)(config=config, num_gpus=config.train.num_gpus)
    model.build_train_model(test=config.train.eval_on_dev)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    summary_writer = tf.summary.FileWriter(pretrain_model_dir,
                                           graph=model.graph)

    with tf.Session(config=sess_config, graph=model.graph) as sess:
        # Initialize all variables.
        sess.run(tf.global_variables_initializer())
        # Reload variables from disk.
        if tf.train.latest_checkpoint(last_pretrain_model_dir):
            available_vars = available_variables_without_global_step(
                last_pretrain_model_dir)
            # available_vars = available_variables(last_pretrain_model_dir)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(
                    sess, tf.train.latest_checkpoint(last_pretrain_model_dir))
                for v in available_vars:
                    logger.info('Reload {} from disk.'.format(v.name))
            else:
                logger.info('Nothing to reload from disk.')
        else:
            logger.info('Nothing to reload from disk.')

        evaluator = Evaluator()
        evaluator.init_from_existed(model, sess, data_reader)

        global dev_bleu, toleration
        dev_bleu = evaluator.evaluate(
            **config.dev) if config.train.eval_on_dev else 0
        toleration = config.train.toleration

        def train_one_step(batch):
            feat_batch, target_batch = batch
            feed_dict = expand_feed_dict({
                model.src_pls: feat_batch,
                model.dst_pls: target_batch
            })
            step, lr, loss, _ = sess.run(
                [model.global_step, model.learning_rate, model.loss,
                 model.train_op],
                feed_dict=feed_dict)
            if step % config.train.summary_freq == 0:
                logger.info('pretrain summary_writer...')
                summary = sess.run(model.summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary, global_step=step)
                summary_writer.flush()
            return step, lr, loss

        def maybe_save_model(model_dir, is_save_global_step=True):
            global dev_bleu, toleration
            new_dev_bleu = evaluator.evaluate(
                **config.dev) if config.train.eval_on_dev else dev_bleu + 1
            if new_dev_bleu >= dev_bleu:
                mp = model_dir + '/pretrain_model_step_{}'.format(step)

                if is_save_global_step:
                    model.saver.save(sess, mp)
                else:
                    variables_without_global_step = \
                        global_variables_without_global_step()
                    saver = tf.train.Saver(
                        var_list=variables_without_global_step, max_to_keep=10)
                    saver.save(sess, mp)

                logger.info('Save model in %s.' % mp)
                toleration = config.train.toleration
                dev_bleu = new_dev_bleu
            else:
                toleration -= 1

        step = 0
        for epoch in range(1, num_epoch + 1):
            for batch in data_reader.get_training_batches_with_buckets():
                # Train normal instances.
                start_time = time.time()
                step, lr, loss = train_one_step(batch)
                logger.info(
                    'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tloss: {3:.4f}\ttime: {4:.4f}'
                    .format(epoch, step, lr, loss,
                            time.time() - start_time))

                if config.train.num_steps and step >= config.train.num_steps:
                    break

            # Early stop, or hard stop once config.train.num_steps is reached
            # (the break inside the batch loop only ends the current epoch).
            if toleration <= 0:
                break
            if config.train.num_steps and step >= config.train.num_steps:
                break

        maybe_save_model(pretrain_model_dir)
        if model_dir:
            maybe_save_model(model_dir, False)
        logger.info("Finish pretrain block_idx_enc=" + str(block_idx_enc) +
                    ',block_idx_dec=' + str(block_idx_dec))
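
Example #5 additionally depends on available_variables_without_global_step and global_variables_without_global_step, which exclude the step counter so each pretraining stage restarts its global_step from zero. A sketch in the same spirit (building on the available_variables sketch above; the name filter is an assumption):

import tensorflow as tf

def global_variables_without_global_step():
    # All global variables except the global_step counter.
    return [v for v in tf.global_variables() if 'global_step' not in v.name]

def available_variables_without_global_step(checkpoint_dir):
    # Checkpoint-matched variables, minus global_step.
    return [v for v in available_variables(checkpoint_dir)
            if 'global_step' not in v.name]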