Example 1
def model_size(config):
    logger = logging.getLogger('')

    config.train.num_gpus = 1
    model = eval(config.model)(config=config, num_gpus=config.train.num_gpus)
    model.build_train_model(test=config.train.eval_on_dev)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    with tf.Session(config=sess_config, graph=model.graph) as sess:
        # Initialize all variables.
        sess.run(tf.global_variables_initializer())
        # Reload variables from disk.
        if tf.train.latest_checkpoint(config.model_dir):
            available_vars = available_variables(config.model_dir)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(sess, tf.train.latest_checkpoint(config.model_dir))
                for v in available_vars:
                    logger.info('Reload {} from disk.'.format(v.name))
                logger.info('=================================')
                import example.ctc.ctc_util as ctc_util
                logger.info(ctc_util.print_nnet_info())
            else:
                logger.info('Nothing to reload from disk.')
        else:
            logger.info('Nothing to reload from disk.')
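The restore logic above relies on an available_variables helper that is not shown in the example. A minimal sketch, assuming it simply intersects the graph's global variables with the names and shapes stored in the latest checkpoint, could look like this:

import tensorflow as tf

def available_variables(checkpoint_dir):
    # Hypothetical sketch (not from the original code): keep only the graph
    # variables whose name and shape also exist in the latest checkpoint, so
    # saver.restore() never fails on newly added or resized variables.
    reader = tf.train.NewCheckpointReader(tf.train.latest_checkpoint(checkpoint_dir))
    saved_shapes = reader.get_variable_to_shape_map()
    available = []
    for var in tf.global_variables():
        name = var.name.split(':')[0]
        if name in saved_shapes and var.get_shape().as_list() == saved_shapes[name]:
            available.append(var)
    return available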
Example 2
def train(config):
    """Train a model with a config file."""
    logger = logging.getLogger('')
    data_reader = DataReader(config=config)
    model = eval(config.model)(config=config, num_gpus=config.train.num_gpus)
    model.build_train_model(test=config.train.eval_on_dev)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    summary_writer = tf.summary.FileWriter(config.model_dir, graph=model.graph)

    with tf.Session(config=sess_config, graph=model.graph) as sess:
        # Initialize all variables.
        sess.run(tf.global_variables_initializer())
        # Reload variables from disk.
        if tf.train.latest_checkpoint(config.model_dir):
            available_vars = available_variables(config.model_dir)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(sess,
                              tf.train.latest_checkpoint(config.model_dir))
                for v in available_vars:
                    logger.info('Reload {} from disk.'.format(v.name))
            else:
                logger.info('Nothing to reload from disk.')
        else:
            logger.info('Nothing to reload from disk.')

        evaluator = Evaluator()
        evaluator.init_from_existed(model, sess, data_reader)

        global dev_bleu, toleration
        dev_bleu = evaluator.evaluate(
            **config.dev) if config.train.eval_on_dev else 0
        toleration = config.train.toleration

        def train_one_step(batch):
            feat_batch, target_batch, batch_size = batch
            feed_dict = expand_feed_dict({
                model.src_pls: feat_batch,
                model.dst_pls: target_batch
            })
            step, lr, loss, _ = sess.run(
                [model.global_step, model.learning_rate, model.loss, model.train_op],
                feed_dict=feed_dict)
            if step % config.train.summary_freq == 0:
                summary = sess.run(model.summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary, global_step=step)
            return step, lr, loss

        def maybe_save_model():
            global dev_bleu, toleration
            new_dev_bleu = evaluator.evaluate(
                **config.dev) if config.train.eval_on_dev else dev_bleu + 1
            if new_dev_bleu >= dev_bleu:
                mp = config.model_dir + '/model_step_{}'.format(step)
                model.saver.save(sess, mp)
                logger.info('Save model in %s.' % mp)
                toleration = config.train.toleration
                dev_bleu = new_dev_bleu
            else:
                toleration -= 1

        step = 0
        for epoch in range(1, config.train.num_epochs + 1):
            for batch in data_reader.get_training_batches_with_buckets():

                # Train normal instances.
                start_time = time.time()
                step, lr, loss = train_one_step(batch)
                logger.info(
                    'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tloss: {3:.4f}\ttime: {4:.4f}\tbatch_size: {5}'
                    .format(epoch, step, lr, loss,
                            time.time() - start_time, batch[2]))
                # Save model
                if config.train.save_freq > 0 and step % config.train.save_freq == 0:
                    maybe_save_model()

                if config.train.num_steps and step >= config.train.num_steps:
                    break

            # Save the model once per epoch if config.train.save_freq is less than or equal to zero
            if config.train.save_freq <= 0:
                maybe_save_model()

            # Early stop
            if toleration <= 0:
                break
        logger.info("Finish training.")
Example 3
def train(config, num_epoch, last_pretrain_model_dir, pretrain_model_dir,
          model_dir, block_idx):
    """Train a model with a config file."""
    logger = logging.getLogger('')
    config.num_blocks = block_idx
    # if block_idx >= 2:
    #     config.train.var_filter = 'encoder/block_' + str(block_idx - 1) + '|' + 'decoder/block_' + str(
    #         block_idx - 1) + '|' + 'encoder/src_embedding' + '|' + 'decoder/dst_embedding'
    # if block_idx >= 2:
    #     config.train.var_filter = 'encoder/block_' + str(block_idx - 1) + '|' + 'decoder/block_' + str(
    #         block_idx - 1)
    logger.info("config.num_blocks=" + str(config.num_blocks) +
                ',config.train.var_filter=' + str(config.train.var_filter))
    """Train a model with a config file."""
    data_reader = DataReader(config=config)
    model = eval(config.model)(config=config, num_gpus=config.train.num_gpus)
    model.build_train_model(test=config.train.eval_on_dev)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    summary_writer = tf.summary.FileWriter(pretrain_model_dir,
                                           graph=model.graph)

    with tf.Session(config=sess_config, graph=model.graph) as sess:
        # Initialize all variables.
        sess.run(tf.global_variables_initializer())
        # Reload variables from disk.
        if tf.train.latest_checkpoint(last_pretrain_model_dir):
            # available_vars = available_variables_without_global_step(last_pretrain_model_dir)
            available_vars = available_variables(last_pretrain_model_dir)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(
                    sess, tf.train.latest_checkpoint(last_pretrain_model_dir))
                for v in available_vars:
                    logger.info('Reload {} from disk.'.format(v.name))
            else:
                logger.info('Nothing to reload from disk.')
        else:
            logger.info('Nothing to reload from disk.')

        evaluator = Evaluator()
        evaluator.init_from_existed(model, sess, data_reader)

        global dev_bleu, toleration
        dev_bleu = evaluator.evaluate(
            **config.dev) if config.train.eval_on_dev else 0
        toleration = config.train.toleration

        def train_one_step(batch):
            feat_batch, target_batch = batch
            feed_dict = expand_feed_dict({
                model.src_pls: feat_batch,
                model.dst_pls: target_batch
            })
            step, lr, loss, _ = sess.run(
                [model.global_step, model.learning_rate, model.loss, model.train_op],
                feed_dict=feed_dict)
            if step % config.train.summary_freq == 0:
                logger.info('pretrain summary_writer...')
                summary = sess.run(model.summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary, global_step=step)
                summary_writer.flush()
            return step, lr, loss

        def maybe_save_model(model_dir, is_save_global_step=True):
            global dev_bleu, toleration
            new_dev_bleu = evaluator.evaluate(
                **config.dev) if config.train.eval_on_dev else dev_bleu + 1
            if new_dev_bleu >= dev_bleu:
                mp = model_dir + '/pretrain_model_step_{}'.format(step)

                # model.saver.save(sess, mp)
                if is_save_global_step:
                    model.saver.save(sess, mp)
                else:
                    variables_without_global_step = global_variables_without_global_step()
                    saver = tf.train.Saver(
                        var_list=variables_without_global_step, max_to_keep=10)
                    saver.save(sess, mp)

                logger.info('Save model in %s.' % mp)
                toleration = config.train.toleration
                dev_bleu = new_dev_bleu
            else:
                toleration -= 1

        step = 0
        for epoch in range(1, num_epoch + 1):
            for batch in data_reader.get_training_batches_with_buckets():
                # Train normal instances.
                start_time = time.time()
                step, lr, loss = train_one_step(batch)
                logger.info(
                    'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tloss: {3:.4f}\ttime: {4:.4f}'
                    .format(epoch, step, lr, loss,
                            time.time() - start_time))

                if config.train.num_steps and step >= config.train.num_steps:
                    break

            # Early stop
            if toleration <= 0:
                break

        maybe_save_model(pretrain_model_dir)
        if model_dir:
            maybe_save_model(model_dir, False)
        logger.info("Finish pretrain layer=" + str(block_idx))
Example 4
def train(config):
    """Train a model with a config file."""
    logger = logging.getLogger('')
    data_reader = DataReader(config=config)
    model = eval(config.model)(config=config, num_gpus=config.train.num_gpus)
    model.build_train_model(test=config.train.eval_on_dev)

    train_op, loss_op = model.get_train_op(name=None)
    global_saver = tf.train.Saver()

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    summary_writer = tf.summary.FileWriter(config.model_dir)

    with tf.Session(config=sess_config) as sess:
        # Initialize all variables.
        sess.run(tf.global_variables_initializer())
        # Reload variables from disk.
        if tf.train.latest_checkpoint(config.model_dir):
            available_vars = available_variables(config.model_dir)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(sess, tf.train.latest_checkpoint(config.model_dir))
                for v in available_vars:
                    logger.info('Reload {} from disk.'.format(v.name))
            else:
                logger.info('Nothing to reload from disk.')
        else:
            logger.info('Nothing to reload from disk.')

        evaluator = Evaluator()
        evaluator.init_from_existed(model, sess, data_reader)

        global dev_bleu, toleration
        dev_bleu = evaluator.evaluate(**config.dev) if config.train.eval_on_dev else 0
        toleration = config.train.toleration

        def train_one_step(batch, loss_op, train_op):
            feed_dict = expand_feed_dict({model.src_pls: batch[0], model.dst_pls: batch[1]})
            step, lr, loss, _ = sess.run(
                [model.global_step, model.learning_rate,
                 loss_op, train_op],
                feed_dict=feed_dict)
            if step % config.train.summary_freq == 0:
                summary = sess.run(model.summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary, global_step=step)
            return step, lr, loss

        def maybe_save_model():
            global dev_bleu, toleration

            def save():
                mp = config.model_dir + '/model_step_{}'.format(step)
                global_saver.save(sess, mp)
                logger.info('Save model in %s.' % mp)

            if config.train.eval_on_dev:
                new_dev_bleu = evaluator.evaluate(**config.dev)

                summary = tf.Summary(value=[tf.Summary.Value(tag="dev_bleu",
                                                             simple_value=new_dev_bleu)])

                summary_writer.add_summary(summary, step)

                if config.train.toleration is None:
                    save()
                else:
                    if new_dev_bleu >= dev_bleu:
                        save()
                        toleration = config.train.toleration
                        dev_bleu = new_dev_bleu
                    else:
                        toleration -= 1
            else:
                save()

        try:
            step = 0
            for epoch in range(1, config.train.num_epochs+1):
                for batch in data_reader.get_training_batches(epoches=1):

                    # Train normal instances.
                    start_time = time.time()
                    step, lr, loss = train_one_step(batch, loss_op, train_op)
                    logger.info(
                        'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tloss: {3:.4f}\ttime: {4:.4f}'.
                        format(epoch, step, lr, loss, time.time() - start_time))
                    # Save model
                    if config.train.save_freq > 0 \
                       and step > 0 \
                       and step % config.train.save_freq == 0:
                        maybe_save_model()

                    if config.train.num_steps is not None and step >= config.train.num_steps:
                        raise BreakLoopException("BreakLoop")

                    if toleration is not None and toleration <= 0:
                        raise BreakLoopException("BreakLoop")

                # Save the model once per epoch if config.train.save_freq is less than or equal to zero
                if config.train.save_freq <= 0:
                    maybe_save_model()
        except BreakLoopException as e:
            logger.info(e)

        logger.info("Finish training.")
Example 5
def train(config):
    """Train a model with a config file."""
    logger = logging.getLogger('')
    train_graph = tf.Graph()
    print(config.train.tfrecord_pattern)
    data_files = tf.gfile.Glob(config.train.tfrecord_pattern)
    logging.info("Find {} tfrecords files".format(len(data_files)))
    with train_graph.as_default():
        data_holder = tf.placeholder(tf.string, shape=[None])
        dataset = tf.data.TFRecordDataset(
            data_holder, num_parallel_reads=config.train.read_threads)
        dataset = dataset.map(parse_function_var,
                              num_parallel_calls=config.train.read_threads)
        shuffle_data = True
        if shuffle_data is True:
            dataset = dataset.shuffle(buffer_size=10000)
        dataset = dataset.repeat(config.train.num_epochs).batch(
            config.train.batchsize_read)

        iterator = dataset.make_initializable_iterator()

        (feat_shape_tensor, feat_tensor, label_shape_tensor,
         label_tensor) = iterator.get_next()

        feat_tensor = tf.sparse_tensor_to_dense(feat_tensor)
        label_tensor = tf.sparse_tensor_to_dense(label_tensor)
        label_tensor = tf.cast(label_tensor, tf.int32)
        feat_tensor_shapeop = tf.shape(feat_tensor)
        feat_tensor = tf.reshape(
            feat_tensor, [feat_tensor_shapeop[0], -1, config.train.input_dim])

    model = eval(config.model)(config=config,
                               num_gpus=config.train.num_gpus,
                               X=feat_tensor,
                               Y=label_tensor,
                               tensor_graph=train_graph)
    model.build_train_model(test=config.train.eval_on_dev)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    summary_writer = tf.summary.FileWriter(config.model_dir, graph=model.graph)

    with tf.Session(config=sess_config, graph=model.graph) as sess:
        # Initialize all variables.
        sess.run(tf.global_variables_initializer())

        sess.run(iterator.initializer, feed_dict={data_holder: data_files})
        # Reload variables from disk.
        if tf.train.latest_checkpoint(config.model_dir):
            available_vars = available_variables(config.model_dir)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(sess,
                              tf.train.latest_checkpoint(config.model_dir))
                for v in available_vars:
                    logger.info('Reload {} from disk.'.format(v.name))
            else:
                logger.info('Nothing to reload from disk.')
        else:
            logger.info('Nothing to reload from disk.')

        global dev_bleu, toleration
        dev_bleu = 0
        toleration = config.train.toleration

        def train_one_step(batch):
            feat_batch, target_batch, batch_size = batch
            feed_dict = expand_feed_dict({
                model.src_pls: feat_batch,
                model.dst_pls: target_batch
            })
            step, lr, loss, _ = sess.run(
                [model.global_step, model.learning_rate, model.loss, model.train_op],
                feed_dict=feed_dict)
            if step % config.train.summary_freq == 0:
                summary = sess.run(model.summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary, global_step=step)
            return step, lr, loss

        def maybe_save_model():
            global dev_bleu, toleration
            new_dev_bleu = dev_bleu + 1
            if new_dev_bleu >= dev_bleu:
                mp = config.model_dir + '/model_step_{}'.format(step)
                model.saver.save(sess, mp)
                logger.info('Save model in %s.' % mp)
                toleration = config.train.toleration
                dev_bleu = new_dev_bleu
            else:
                toleration -= 1

        step = 0
        while True:
            try:
                pre_train_time = time.time()
                feat_shape, feat, label_shape, label = sess.run([
                    feat_shape_tensor, feat_tensor, label_shape_tensor,
                    label_tensor
                ])
                batch = (feat, label, feat.shape[0])
                #logging.info("This batch has {} samples".format(feat.shape[0]))
                #logging.info("The feat shape is {}".format(feat.shape))
                # Train normal instances.
                start_time = time.time()
                step, lr, loss = train_one_step(batch)
                logger.info(
                    'step: {0}\tlr: {1:.6f}\tloss: {2:.4f}\ttrain_time: {3:.4f}\tpre_train_time: {4:.5f}\tbatch_size: {5}'
                    .format(step, lr, loss,
                            time.time() - start_time,
                            start_time - pre_train_time, batch[2]))
                # Save model
                pre_train_time = time.time()
                if config.train.save_freq > 0 and step % config.train.save_freq == 0:
                    maybe_save_model()

                if config.train.num_steps and step >= config.train.num_steps:
                    break

                # Save the model once per epoch if config.train.save_freq is less than or equal to zero
                if config.train.save_freq <= 0:
                    maybe_save_model()

                # Early stop
                if toleration <= 0:
                    break
            except tf.errors.OutOfRangeError:
                logging.info("All data done!")
                break
        logger.info("Finish training.")