Example #1
0
File: train.py  Project: watanka/SATRN
def main(config_file):
    """ Train text recognition network

    Builds a multi-tower (one tower per visible GPU) TensorFlow 1.x
    training graph for a text recognition model, then runs a
    MonitoredTrainingSession loop with periodic train/valid evaluation,
    summary writing, and best-checkpoint tracking per validation dataset.

    Args:
        config_file: path of the configuration file parsed by `Flags`.
    """
    # Parse configs
    FLAGS = Flags(config_file).get()

    # Set directory, seed, logger
    model_dir = create_model_dir(FLAGS.model_dir)
    logger = get_logger(model_dir, 'train')
    best_model_dir = os.path.join(model_dir, 'best_models')
    set_seed(FLAGS.seed)

    # Print configs
    flag_strs = [
        '{}:\t{}'.format(name, value)
        for name, value in FLAGS._asdict().items()
    ]
    log_formatted(logger, '[+] Model configurations', *flag_strs)

    # Print system environments
    num_gpus = count_available_gpus()
    num_cpus = os.cpu_count()
    # Available RAM in whole gigabytes (integer floor division).
    mem_size = virtual_memory().available // (1024**3)
    log_formatted(logger, '[+] System environments',
                  'The number of gpus : {}'.format(num_gpus),
                  'The number of cpus : {}'.format(num_cpus),
                  'Memory Size : {}G'.format(mem_size))

    # Get optimizer and network
    global_step = tf.train.get_or_create_global_step()
    optimizer, learning_rate = get_optimizer(FLAGS.train.optimizer,
                                             global_step)
    out_charset = load_charset(FLAGS.charset)
    net = get_network(FLAGS, out_charset)
    # CTC-style networks need different label encoding / decoding downstream.
    is_ctc = (net.loss_fn == 'ctc_loss')

    # Multi tower for multi-gpu training
    tower_grads = []
    tower_extra_update_ops = []
    tower_preds = []
    tower_gts = []
    tower_losses = []
    batch_size = FLAGS.train.batch_size
    # Base per-GPU batch; the last tower absorbs the remainder (see below).
    tower_batch_size = batch_size // num_gpus

    val_tower_outputs = []
    eval_tower_outputs = []

    for gpu_indx in range(num_gpus):

        # Train tower
        print('[+] Build Train tower GPU:%d' % gpu_indx)
        input_device = '/gpu:%d' % gpu_indx

        # All towers get batch_size // num_gpus samples except the last one,
        # which takes the remainder so the towers sum to exactly batch_size.
        tower_batch_size = tower_batch_size \
            if gpu_indx < num_gpus-1 \
            else batch_size - tower_batch_size * (num_gpus-1)

        # NOTE(review): 'DatasetLodaer' is a (mis-spelled) project class name;
        # kept as-is to match the rest of the codebase.
        train_loader = DatasetLodaer(
            dataset_paths=FLAGS.train.dataset_paths,
            dataset_portions=FLAGS.train.dataset_portions,
            batch_size=tower_batch_size,
            label_maxlen=FLAGS.label_maxlen,
            out_charset=out_charset,
            preprocess_image=net.preprocess_image,
            is_train=True,
            is_ctc=is_ctc,
            shuffle_and_repeat=True,
            concat_batch=True,
            input_device=input_device,
            num_cpus=num_cpus,
            num_gpus=num_gpus,
            worker_index=gpu_indx,
            use_rgb=FLAGS.use_rgb,
            seed=FLAGS.seed,
            name='train')

        tower_output = single_tower(net,
                                    gpu_indx,
                                    train_loader,
                                    out_charset,
                                    optimizer,
                                    name='train',
                                    is_train=True)
        # Drop (grad, var) pairs whose gradient is None (vars not reached
        # by this tower's loss) before averaging across towers.
        tower_grads.append([x for x in tower_output.grads if x[0] is not None])
        tower_extra_update_ops.append(tower_output.extra_update_ops)
        tower_preds.append(tower_output.prediction)
        tower_gts.append(tower_output.text)
        tower_losses.append(tower_output.loss)

        # Print network structure (parameter count) once, for the first tower.
        if gpu_indx == 0:
            param_stats = tf.profiler.profile(tf.get_default_graph())
            logger.info('total_params: %d\n' % param_stats.total_parameters)

        # Valid tower
        print('[+] Build Valid tower GPU:%d' % gpu_indx)
        valid_loader = DatasetLodaer(dataset_paths=FLAGS.valid.dataset_paths,
                                     dataset_portions=None,
                                     batch_size=FLAGS.valid.batch_size //
                                     num_gpus,
                                     label_maxlen=FLAGS.label_maxlen,
                                     out_charset=out_charset,
                                     preprocess_image=net.preprocess_image,
                                     is_train=False,
                                     is_ctc=is_ctc,
                                     shuffle_and_repeat=False,
                                     concat_batch=False,
                                     input_device=input_device,
                                     num_cpus=num_cpus,
                                     num_gpus=num_gpus,
                                     worker_index=gpu_indx,
                                     use_rgb=FLAGS.use_rgb,
                                     seed=FLAGS.seed,
                                     name='valid')

        val_tower_output = single_tower(net,
                                        gpu_indx,
                                        valid_loader,
                                        out_charset,
                                        optimizer=None,
                                        name='valid',
                                        is_train=False)

        val_tower_outputs.append(
            (val_tower_output.loss, val_tower_output.prediction,
             val_tower_output.text, val_tower_output.filename,
             val_tower_output.dataset))

    # Aggregate gradients across towers
    losses = tf.reduce_mean(tower_losses)
    grads = _average_gradients(tower_grads)

    # NOTE(review): only the LAST tower's extra update ops (e.g. batch-norm
    # moving averages, presumably) gate the train op — TODO confirm this is
    # intentional rather than running every tower's update ops.
    with tf.control_dependencies(tower_extra_update_ops[-1]):
        if FLAGS.train.optimizer.grad_clip_norm is not None:
            grads, global_norm = _clip_gradients(
                grads, FLAGS.train.optimizer.grad_clip_norm)
            tf.summary.scalar('global_norm', global_norm)

        train_op = optimizer.apply_gradients(grads, global_step=global_step)

    # Define config, scaffold
    saver = tf.train.Saver()
    sess_config = get_session_config()
    scaffold = get_scaffold(saver, FLAGS.train.tune_from, 'train')
    # NOTE(review): `restore_model` is assigned but never used in this
    # function; restoration appears to be handled by the scaffold instead.
    restore_model = get_init_trained()

    # Define validation saver, summary writer (one writer per valid dataset,
    # plus an aggregate 'total_valid' writer).
    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    val_summary_op = tf.summary.merge(
        [s for s in summaries if 'valid' in s.name])
    val_summary_writer = {
        dataset_name:
        tf.summary.FileWriter(os.path.join(model_dir, 'valid', dataset_name))
        for dataset_name in valid_loader.dataset_names
    }
    val_summary_writer['total_valid'] = tf.summary.FileWriter(
        os.path.join(model_dir, 'valid', 'total_valid'))
    # One best checkpoint may be kept per dataset plus one for 'total_valid'.
    val_saver = tf.train.Saver(max_to_keep=len(valid_loader.dataset_names) + 1)
    best_val_err_rates = {}
    best_steps = {}

    # Training
    print('[+] Make Session...')

    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=model_dir,
            scaffold=scaffold,
            config=sess_config,
            save_checkpoint_steps=FLAGS.train.save_steps,
            save_checkpoint_secs=None,
            save_summaries_steps=FLAGS.train.summary_steps,
            save_summaries_secs=None,
    ) as sess:

        log_formatted(logger, 'Training started!')
        _step = 0
        train_t = 0
        start_t = time.time()

        while _step < FLAGS.train.max_num_steps \
                and not sess.should_stop():

            # Train step: run loss + train op, and fetch tower-0 predictions
            # and ground truths for the periodic train-error report.
            step_t = time.time()
            [step_loss, _, _step, preds, gts, lr] = sess.run([
                losses, train_op, global_step, tower_preds[0], tower_gts[0],
                learning_rate
            ])
            train_t += time.time() - step_t

            # Summary / validation every FLAGS.valid.steps steps
            if _step % FLAGS.valid.steps == 0:

                # Train summary: sequence-level error on tower-0 batch
                train_err = 0.

                for i, (p, g) in enumerate(zip(preds, gts)):
                    s = get_string(p, out_charset, is_ctc=is_ctc)
                    g = g.decode('utf8').replace(DELIMITER, '')

                    # Normalize case/charset before comparing, per config.
                    s = adjust_string(s, FLAGS.train.lowercase,
                                      FLAGS.train.alphanumeric)
                    g = adjust_string(g, FLAGS.train.lowercase,
                                      FLAGS.train.alphanumeric)
                    # Whole-sequence error: 1 if any character differs.
                    e = int(s != g)

                    train_err += e

                    if FLAGS.train.verbose and i < 5:
                        print('TRAIN :\t{}\t{}\t{}'.format(s, g, not bool(e)))

                train_err_rate = \
                    train_err / len(gts)

                # Valid summary: also updates best checkpoints / best steps.
                val_cnts, val_errs, val_err_rates, _ = \
                    validate(sess,
                             _step,
                             val_tower_outputs,
                             out_charset,
                             is_ctc,
                             val_summary_op,
                             val_summary_writer,
                             val_saver,
                             best_val_err_rates,
                             best_steps,
                             best_model_dir,
                             FLAGS.valid.lowercase,
                             FLAGS.valid.alphanumeric)

                # Logging: per-dataset accuracy detail block
                log_strings = ['', '-' * 28 + ' VALID_DETAIL ' + '-' * 28, '']

                for dataset in sorted(val_err_rates.keys()):
                    if dataset == 'total_valid':
                        continue

                    cnt = val_cnts[dataset]
                    err = val_errs[dataset]
                    err_rate = val_err_rates[dataset]
                    best_step = best_steps[dataset]

                    s = '%s : %.2f%%(%d/%d)\tBEST_STEP : %d' % \
                        (dataset, (1.-err_rate)*100, cnt-err, cnt, best_step)

                    log_strings.append(s)

                # Elapsed / remaining time estimate in minutes.
                elapsed_t = float(time.time() - start_t) / 60
                remain_t = (elapsed_t / (_step+1)) * \
                    (FLAGS.train.max_num_steps - _step - 1)
                log_formatted(
                    logger, 'STEP : %d\tTRAIN_LOSS : %f' % (_step, step_loss),
                    'ELAPSED : %.2f min\tREMAIN : %.2f min\t'
                    'STEP_TIME: %.1f sec' %
                    (elapsed_t, remain_t, float(train_t) / (_step + 1)),
                    'TRAIN_SEQ_ERR : %f\tVALID_SEQ_ERR : %f' %
                    (train_err_rate, val_err_rates['total_valid']),
                    'BEST_STEP : %d\tBEST_VALID_SEQ_ERR : %f' %
                    (best_steps['total_valid'],
                     best_val_err_rates['total_valid']), *log_strings)

        log_formatted(logger, 'Training is completed!')
Example #2
0
def main(config_file=None):
    """ Run evaluation.

    Builds a multi-tower (one per visible GPU) TensorFlow 1.x evaluation
    graph, restores the checkpoint under ``FLAGS.eval.model_path``, runs
    `validate` once over the evaluation datasets, and writes per-dataset
    prediction files plus an accuracy summary into the results directory.

    Args:
        config_file: path of the configuration file parsed by `Flags`
            (None defers to `Flags`' own default handling).

    Fixes vs. previous revision:
      * removed dead assignment ``infet_t = 0`` (typo; ``infer_t`` is
        freshly computed from the timer below and never read before that),
      * per-dataset result files are opened with a context manager so the
        handle is closed even if a write raises,
      * dropped the no-op list comprehension inside ``tf.summary.merge``.
    """
    # Parse Config
    print('[+] Model configurations')
    FLAGS = Flags(config_file).get()
    for name, value in FLAGS._asdict().items():
        print('{}:\t{}'.format(name, value))
    print('\n')

    # System environments
    num_gpus = count_available_gpus()
    num_cpus = os.cpu_count()
    # Available RAM in whole gigabytes (integer floor division).
    mem_size = virtual_memory().available // (1024**3)
    out_charset = load_charset(FLAGS.charset)
    print('[+] System environments')
    print('The number of gpus : {}'.format(num_gpus))
    print('The number of cpus : {}'.format(num_cpus))
    print('Memory Size : {}G'.format(mem_size))
    print('The number of characters : {}\n'.format(len(out_charset)))

    # Make results dir (results are written next to the evaluated model).
    res_dir = os.path.join(FLAGS.eval.model_path)
    os.makedirs(res_dir, exist_ok=True)

    # Get network
    net = get_network(FLAGS, out_charset)
    # CTC-style networks need different label encoding / decoding downstream.
    is_ctc = (net.loss_fn == 'ctc_loss')

    # Define Graph
    eval_tower_outputs = []
    global_step = tf.train.get_or_create_global_step()

    for gpu_indx in range(num_gpus):
        # Get eval dataset
        input_device = '/gpu:%d' % gpu_indx
        print('[+] Build Eval tower GPU:%d' % gpu_indx)

        # NOTE(review): 'DatasetLodaer' is a (mis-spelled) project class
        # name; kept as-is to match the rest of the codebase.
        eval_loader = DatasetLodaer(dataset_paths=FLAGS.eval.dataset_paths,
                                    dataset_portions=None,
                                    batch_size=FLAGS.eval.batch_size,
                                    label_maxlen=FLAGS.label_maxlen,
                                    out_charset=out_charset,
                                    preprocess_image=net.preprocess_image,
                                    is_train=False,
                                    is_ctc=is_ctc,
                                    shuffle_and_repeat=False,
                                    concat_batch=False,
                                    input_device=input_device,
                                    num_cpus=num_cpus,
                                    num_gpus=num_gpus,
                                    worker_index=gpu_indx,
                                    use_rgb=FLAGS.use_rgb,
                                    seed=FLAGS.seed,
                                    name='eval')

        eval_tower_output = single_tower(net,
                                         gpu_indx,
                                         eval_loader,
                                         out_charset,
                                         optimizer=None,
                                         name='eval',
                                         is_train=False)

        eval_tower_outputs.append(
            (eval_tower_output.loss, eval_tower_output.prediction,
             eval_tower_output.text, eval_tower_output.filename,
             eval_tower_output.dataset))

    # Summary: one writer per eval dataset, plus an aggregate writer.
    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    summary_op = tf.summary.merge(summaries)
    summary_writer = {
        dataset_name:
        tf.summary.FileWriter(os.path.join(res_dir, dataset_name))
        for dataset_name in eval_loader.dataset_names
    }
    # NOTE(review): key is 'total_valid' (what `validate` looks up,
    # presumably) while the directory is 'total_eval' — confirm intent.
    summary_writer['total_valid'] = tf.summary.FileWriter(
        os.path.join(res_dir, 'total_eval'))

    # Define config, scaffold, hooks
    saver = tf.train.Saver()
    sess_config = get_session_config()
    restore_model = get_init_trained()
    scaffold = get_scaffold(saver, None, 'eval')

    # Testing
    with tf.train.MonitoredTrainingSession(scaffold=scaffold,
                                           config=sess_config) as sess:

        # Restore checkpoint and read the step it was saved at.
        restore_model(sess, FLAGS.eval.model_path)
        _step = sess.run(global_step)

        # Run test, timing the full validation pass.
        start_t = time.time()
        eval_cnts, eval_errs, eval_err_rates, eval_preds = \
            validate(sess,
                     _step,
                     eval_tower_outputs,
                     out_charset,
                     is_ctc,
                     summary_op,
                     summary_writer,
                     lowercase=FLAGS.eval.lowercase,
                     alphanumeric=FLAGS.eval.alphanumeric)
        infer_t = time.time() - start_t

    # Log per-dataset results and write prediction files.
    total_total = 0

    for dataset, result in eval_preds.items():
        total = eval_cnts[dataset]
        correct = total - eval_errs[dataset]
        acc = 1. - eval_err_rates[dataset]
        total_total += total

        with open(os.path.join(res_dir, '{}.txt'.format(dataset)),
                  'w') as res_file:
            for f, s, g in result:
                f = f.decode('utf8')

                if FLAGS.eval.verbose:
                    print('FILE : ' + f)
                    print('PRED : ' + s)
                    print('ANSW : ' + g)
                    print('=' * 50)

                res_file.write('{}\t{}\n'.format(f, s))

            res_s = 'DATASET : %s\tCORRECT : %d\tTOTAL : %d\tACC : %f' % \
                    (dataset, correct, total, acc)
            print(res_s)
            # NOTE(review): summary line is appended without a trailing
            # newline, matching the previous behavior.
            res_file.write(res_s)

    # Clean up the temp files the (last) loader created.
    eval_loader.flush_tmpfile()
    print('INFER TIME(PER IMAGE) : %f s' % (float(infer_t) / total_total))