def train(dataset_dir, weights_path, train_data_num, gpu_num):
    """
    Train the Chinese CRNN model on several GPUs via ``ChineseCrnnNet.multi_gpu_train``.

    :param dataset_dir: dataset directory; the train/val feeders read from it
        and the feature decoder loads ``lexicon.txt`` from it
    :param weights_path: pretrained weights, passed as ``pretrained_model``
    :param train_data_num: number of training samples, used to derive the
        train / validation / checkpoint step counts
    :param gpu_num: number of GPUs to train on
    :return: None
    """
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=sess_config)

    # prepare dataset
    train_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir, flags='train')
    val_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir, flags='val')
    train_images, train_labels, train_images_paths = train_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE)
    val_images, val_labels, val_images_paths = val_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE)

    # Number of multi-GPU steps in one pass over the training samples.
    # NOTE(review): this uses the module-level ``batch_size`` while the feeders
    # above use CFG.TRAIN.BATCH_SIZE -- presumably the same value; confirm.
    steps_per_pass = train_data_num // (batch_size * gpu_num)
    train_epochs = steps_per_pass * 100
    val_epochs = steps_per_pass
    save_epochs = steps_per_pass
    show_epochs = 100

    # os.path.join inserts the path separator; string concatenation produced
    # a wrong path whenever dataset_dir had no trailing slash.
    decoder = tf_io_pipline_fast_tools.FeatureDecoder(
        lexicon_path=os.path.join(dataset_dir, "lexicon.txt"))
    chinese_crnn = ChineseCrnnNet(hidden_nums=hidden_nums,
                                  layers_nums=hidden_layers,
                                  num_classes=num_classes,
                                  pretrained_model=weights_path,
                                  sess=sess,
                                  feature_decoder=decoder,
                                  learning_rate=learning_rate,
                                  lr_decay_steps=lr_decay_steps,
                                  lr_decay_rate=lr_decay_rate,
                                  lr_staircase=lr_staircase)

    chinese_crnn.multi_gpu_train(gpu_num=gpu_num,
                                 train_input_data=train_images,
                                 train_label=train_labels,
                                 val_input_data=val_images,
                                 val_label=val_labels,
                                 sql_len=seq_len,
                                 batch_size=batch_size,
                                 name="chinese_crnn",
                                 val_epochs=val_epochs,
                                 train_epochs=train_epochs,
                                 save_epochs=save_epochs,
                                 show_epochs=show_epochs)
def train(dataset_dir, weights_path, train_data_num):
    """
    Train the Chinese CRNN model on a single device via ``ChineseCrnnNet.train``.

    NOTE(review): this module defines ``train`` twice; at import time this
    definition replaces the earlier multi-GPU ``train``.

    :param dataset_dir: dataset directory; the train/val feeders read from it
        and the feature decoder loads ``lexicon.txt`` from it
    :param weights_path: pretrained weights, passed as ``pretrained_model``
    :param train_data_num: number of training samples, forwarded to
        ``ChineseCrnnNet.train``
    :return: None
    """
    sess = tf.Session()

    # prepare dataset
    train_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir, flags='train')
    val_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir, flags='val')
    train_images, train_labels, train_images_paths = train_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE)
    val_images, val_labels, val_images_paths = val_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE)

    # os.path.join inserts the path separator; string concatenation produced
    # a wrong path whenever dataset_dir had no trailing slash.
    decoder = tf_io_pipline_fast_tools.FeatureDecoder(
        lexicon_path=os.path.join(dataset_dir, "lexicon.txt"))
    chinese_crnn = ChineseCrnnNet(hidden_nums=hidden_nums,
                                  layers_nums=hidden_layers,
                                  num_classes=num_classes,
                                  pretrained_model=weights_path,
                                  sess=sess,
                                  feature_decoder=decoder,
                                  learning_rate=learning_rate,
                                  lr_decay_steps=lr_decay_steps,
                                  lr_decay_rate=lr_decay_rate,
                                  lr_staircase=lr_staircase)
    chinese_crnn.train(train_input_data=train_images,
                       train_label=train_labels,
                       val_input_data=val_images,
                       val_label=val_labels,
                       sql_len=seq_len,
                       batch_size=batch_size,
                       name="chinese_crnn",
                       train_data_num=train_data_num,
                       val_times=val_times,
                       train_epochs=epochs,
                       show_step=show_step,
                       val_step=val_step,
                       need_decode=need_decode,
                       tboard_save_dir='../tboard/crnn_chinese_ocr',
                       model_save_dir='../ckpt/chinese_ocr')
# Example #3 (0)
def validation_data(dataset_dir, weights_path):
    """
    Run ``ChineseCrnnNet.validation`` on the validation split.

    :param dataset_dir: dataset directory; the val feeder reads from it and the
        feature decoder loads ``lexicon.txt`` from it
    :param weights_path: pretrained weights, passed as ``pretrained_model``
    :return: None
    """
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=sess_config)
    val_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir, flags='val')
    val_images, val_labels, val_images_paths = val_dataset.inputs(
        batch_size=CFG.TRAIN.VAL_BATCH_SIZE)
    # os.path.join inserts the path separator; string concatenation produced
    # a wrong path whenever dataset_dir had no trailing slash.
    decoder = tf_io_pipline_fast_tools.FeatureDecoder(
        lexicon_path=os.path.join(dataset_dir, "lexicon.txt"))
    chinese_crnn = ChineseCrnnNet(hidden_nums=hidden_nums,
                                  layers_nums=hidden_layers,
                                  num_classes=num_classes,
                                  pretrained_model=weights_path,
                                  sess=sess,
                                  feature_decoder=decoder,
                                  learning_rate=learning_rate,
                                  lr_decay_steps=lr_decay_steps,
                                  lr_decay_rate=lr_decay_rate,
                                  lr_staircase=lr_staircase)
    # NOTE(review): the feeder uses CFG.TRAIN.VAL_BATCH_SIZE but the module-level
    # ``batch_size`` is passed here -- confirm these agree.
    chinese_crnn.validation(val_images,
                            val_labels,
                            sql_len=seq_len,
                            batch_size=batch_size,
                            val_times=val_times,
                            name="chinese_crnn")
def evaluate_shadownet(dataset_dir,
                       weights_path,
                       char_dict_path,
                       ord_map_dict_path,
                       is_visualize=False,
                       is_process_all_data=False):
    """
    Evaluate a ShadowNet (CRNN) checkpoint on the test split.

    Logs every sample's ground-truth and predicted label, logs the mean
    per-char and full-sequence accuracy over the evaluated batches, and plots
    a normalized confusion matrix of position-aligned characters.

    :param dataset_dir: directory containing the test records
    :param weights_path: checkpoint restored with ``tf.train.Saver``
    :param char_dict_path: character dictionary for the feeder and decoder
    :param ord_map_dict_path: ord map dictionary for the feeder and decoder
    :param is_visualize: if True, display every test image with matplotlib
    :param is_process_all_data: if True, evaluate
        ceil(sample_count / CFG.TEST.BATCH_SIZE) batches; otherwise one batch
    :return: None
    """
    # prepare dataset
    test_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir,
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path,
        flags='test')
    test_images, test_labels, test_images_paths = test_dataset.inputs(
        batch_size=CFG.TEST.BATCH_SIZE)

    # set up test sample count
    if is_process_all_data:
        log.info('Start computing test dataset sample counts')
        t_start = time.time()
        test_sample_count = test_dataset.sample_counts()
        log.info(
            'Computing test dataset sample counts finished, cost time: {:.5f}'.
            format(time.time() - t_start))
        num_iterations = int(math.ceil(test_sample_count /
                                       CFG.TEST.BATCH_SIZE))
    else:
        num_iterations = 1

    # declare crnn net
    shadownet = crnn_net.ShadowNet(phase='test',
                                   hidden_nums=CFG.ARCH.HIDDEN_UNITS,
                                   layers_nums=CFG.ARCH.HIDDEN_LAYERS,
                                   num_classes=CFG.ARCH.NUM_CLASSES)
    # set up decoder
    decoder = tf_io_pipline_fast_tools.CrnnFeatureReader(
        char_dict_path=char_dict_path, ord_map_dict_path=ord_map_dict_path)

    # compute inference result
    test_inference_ret = shadownet.inference(inputdata=test_images,
                                             name='shadow_net',
                                             reuse=False)
    test_decoded, _ = tf.nn.ctc_beam_search_decoder(
        test_inference_ret,
        CFG.ARCH.SEQ_LENGTH * np.ones(CFG.TEST.BATCH_SIZE),
        beam_width=1,
        merge_repeated=False)

    # recover image from [-1.0, 1.0] ---> [0.0, 255.0]
    test_images = tf.multiply(tf.add(test_images, 1.0),
                              127.5,
                              name='recoverd_test_images')

    # Set saver configuration
    saver = tf.train.Saver()

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH

    sess = tf.Session(config=sess_config)

    try:
        with sess.as_default():
            saver.restore(sess=sess, save_path=weights_path)

            log.info('Start predicting...')

            per_char_accuracy = 0.0
            full_sequence_accuracy = 0.0
            # Batches whose accuracy was actually accumulated; used as the
            # divisor so an early stop does not deflate the averages.
            evaluated_batches = 0

            total_labels_char_list = []
            total_predictions_char_list = []

            # Evaluate at most ``num_iterations`` batches. An exhausted input
            # pipeline or any other error ends evaluation early (best effort).
            try:
                for _ in range(num_iterations):
                    (test_predictions_value, test_images_value,
                     test_labels_value, test_images_paths_value) = sess.run(
                        [test_decoded, test_images, test_labels, test_images_paths])
                    test_images_paths_value = np.reshape(
                        test_images_paths_value,
                        (test_images_paths_value.shape[0],))
                    test_images_names_value = [
                        ops.split(tmp.decode('utf-8'))[1]
                        for tmp in test_images_paths_value
                    ]
                    test_labels_value = decoder.sparse_tensor_to_str(
                        test_labels_value)
                    test_predictions_value = decoder.sparse_tensor_to_str(
                        test_predictions_value[0])

                    per_char_accuracy += evaluation_tools.compute_accuracy(
                        test_labels_value,
                        test_predictions_value,
                        display=False,
                        mode='per_char')
                    full_sequence_accuracy += evaluation_tools.compute_accuracy(
                        test_labels_value,
                        test_predictions_value,
                        display=False,
                        mode='full_sequence')
                    evaluated_batches += 1

                    for index, test_image in enumerate(test_images_value):
                        log.info(
                            'Predict {:s} image with gt label: {:s} **** predicted label: {:s}'
                            .format(test_images_names_value[index],
                                    test_labels_value[index],
                                    test_predictions_value[index]))

                        if is_visualize:
                            # reverse the channel order before display
                            plt.imshow(
                                np.array(test_image, np.uint8)[:, :, (2, 1, 0)])
                            plt.show()

                        label_chars = list(test_labels_value[index])
                        prediction_chars = list(test_predictions_value[index])

                        if not label_chars or not prediction_chars:
                            continue

                        # Pair characters position by position for the
                        # confusion matrix, keeping the common length only.
                        min_length = min(len(label_chars), len(prediction_chars))
                        total_labels_char_list.extend(label_chars[:min_length])
                        total_predictions_char_list.extend(
                            prediction_chars[:min_length])

            except tf.errors.OutOfRangeError:
                log.error('End of tfrecords sequence')
            except Exception as err:
                log.error(err)

            if evaluated_batches == 0:
                log.error('No test batch was evaluated')
                return

            avg_per_char_accuracy = per_char_accuracy / evaluated_batches
            avg_full_sequence_accuracy = full_sequence_accuracy / evaluated_batches
            log.info('Mean test per char accuracy is {:5f}'.format(
                avg_per_char_accuracy))
            log.info('Mean test full sequence accuracy is {:5f}'.format(
                avg_full_sequence_accuracy))

            if not total_labels_char_list:
                log.error('No aligned characters available for the confusion matrix')
                return

            # compute confusion matrix
            cnf_matrix = confusion_matrix(total_labels_char_list,
                                          total_predictions_char_list)
            np.set_printoptions(precision=2)
            evaluation_tools.plot_confusion_matrix(cm=cnf_matrix, normalize=True)

            plt.show()
    finally:
        sess.close()
# Example #5 (0)
def train_shadownet_multi_gpu(dataset_dir, weights_path, char_dict_path, ord_map_dict_path):
    """
    Train ShadowNet with one data-parallel tower per GPU (CFG.TRAIN.GPU_NUM).

    Each training tower computes a loss and gradients via
    ``compute_net_gradients``. The gradients are averaged with
    ``average_gradients`` and applied once per step, grouped with the
    UPDATE_OPS collected at the first tower and an exponential moving average
    of the trainable variables. A validation tower per GPU produces the
    validation loss, which is evaluated after every training step.

    TensorBoard summaries go to ``tboard/crnn_syn90k_multi_gpu`` and
    checkpoints to ``model/crnn_syn90k_multi_gpu`` every 5000 steps.

    :param dataset_dir: directory containing the train/val records
    :param weights_path: checkpoint to resume from, or None to train from scratch
    :param char_dict_path: character dictionary path for the data feeders
    :param ord_map_dict_path: ord map dictionary path for the data feeders
    :return: None
    :raises ValueError: if the averaged training loss becomes NaN
    """
    # prepare dataset information
    train_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir,
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path,
        flags='train'
    )
    val_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir,
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path,
        flags='val'
    )
    train_images, train_labels, train_images_paths = train_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE
    )
    val_images, val_labels, val_images_paths = val_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE
    )

    # set crnn net
    shadownet = crnn_net.ShadowNet(
        phase='train',
        hidden_nums=CFG.ARCH.HIDDEN_UNITS,
        layers_nums=CFG.ARCH.HIDDEN_LAYERS,
        num_classes=CFG.ARCH.NUM_CLASSES
    )
    shadownet_val = crnn_net.ShadowNet(
        phase='test',
        hidden_nums=CFG.ARCH.HIDDEN_UNITS,
        layers_nums=CFG.ARCH.HIDDEN_LAYERS,
        num_classes=CFG.ARCH.NUM_CLASSES
    )

    # set average container
    tower_grads = []
    train_tower_loss = []
    val_tower_loss = []
    batchnorm_updates = None
    train_summary_op_updates = None

    # set lr
    global_step = tf.Variable(0, name='global_step', trainable=False)
    learning_rate = tf.train.exponential_decay(
        learning_rate=CFG.TRAIN.LEARNING_RATE,
        global_step=global_step,
        decay_steps=CFG.TRAIN.LR_DECAY_STEPS,
        decay_rate=CFG.TRAIN.LR_DECAY_RATE,
        staircase=CFG.TRAIN.LR_STAIRCASE)

    # set up optimizer
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)

    # set distributed train op
    with tf.variable_scope(tf.get_variable_scope()):
        is_network_initialized = False
        for i in range(CFG.TRAIN.GPU_NUM):
            with tf.device('/gpu:{:d}'.format(i)):
                with tf.name_scope('tower_{:d}'.format(i)) as _:
                    train_loss, grads = compute_net_gradients(
                        train_images, train_labels, shadownet, optimizer,
                        is_net_first_initialized=is_network_initialized)

                    is_network_initialized = True

                    # Only use the mean and var in the first gpu tower to update the parameter
                    if i == 0:
                        batchnorm_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                        train_summary_op_updates = tf.get_collection(tf.GraphKeys.SUMMARIES)

                    tower_grads.append(grads)
                    train_tower_loss.append(train_loss)
                with tf.name_scope('validation_{:d}'.format(i)) as _:
                    val_loss, _ = compute_net_gradients(
                        val_images, val_labels, shadownet_val, optimizer,
                        is_net_first_initialized=is_network_initialized)
                    val_tower_loss.append(val_loss)

    grads = average_gradients(tower_grads)
    avg_train_loss = tf.reduce_mean(train_tower_loss)
    avg_val_loss = tf.reduce_mean(val_tower_loss)

    # Track the moving averages of all trainable variables
    variable_averages = tf.train.ExponentialMovingAverage(
        CFG.TRAIN.MOVING_AVERAGE_DECAY, num_updates=global_step)
    variables_to_average = tf.trainable_variables() + tf.moving_average_variables()
    variables_averages_op = variable_averages.apply(variables_to_average)

    # Group all the op needed for training
    batchnorm_updates_op = tf.group(*batchnorm_updates)
    apply_gradient_op = optimizer.apply_gradients(grads, global_step=global_step)
    train_op = tf.group(apply_gradient_op, variables_averages_op,
                        batchnorm_updates_op)

    # set tensorflow summary
    tboard_save_path = 'tboard/crnn_syn90k_multi_gpu'
    os.makedirs(tboard_save_path, exist_ok=True)

    summary_writer = tf.summary.FileWriter(tboard_save_path)

    avg_train_loss_scalar = tf.summary.scalar(name='average_train_loss',
                                              tensor=avg_train_loss)
    avg_val_loss_scalar = tf.summary.scalar(name='average_val_loss',
                                            tensor=avg_val_loss)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate_scalar',
                                             tensor=learning_rate)
    train_merge_summary_op = tf.summary.merge(
        [avg_train_loss_scalar, learning_rate_scalar] + train_summary_op_updates
    )
    val_merge_summary_op = tf.summary.merge([avg_val_loss_scalar])

    # set tensorflow saver
    saver = tf.train.Saver()
    model_save_dir = 'model/crnn_syn90k_multi_gpu'
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'shadownet_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # set sess config
    sess_config = tf.ConfigProto(device_count={'GPU': CFG.TRAIN.GPU_NUM}, allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    logger.info('Global configuration is as follows:')
    logger.info(CFG)

    sess = tf.Session(config=sess_config)

    summary_writer.add_graph(sess.graph)

    # try/finally so the session and event file are released even when
    # training aborts (e.g. the NaN-loss ValueError below).
    try:
        with sess.as_default():
            epoch = 0
            tf.train.write_graph(graph_or_graph_def=sess.graph, logdir='',
                                 name='{:s}/shadownet_model.pb'.format(model_save_dir))

            if weights_path is None:
                logger.info('Training from scratch')
                init = tf.global_variables_initializer()
                sess.run(init)
            else:
                logger.info('Restore model from last model checkpoint {:s}'.format(weights_path))
                saver.restore(sess=sess, save_path=weights_path)
                # Resume counting from the restored step variable.
                epoch = sess.run(global_step)

            train_cost_time_mean = []
            val_cost_time_mean = []

            while epoch < train_epochs:
                # ``epoch`` is the 1-based number of the step run below.
                epoch += 1
                # training part
                t_start = time.time()

                _, train_loss_value, train_summary, lr = \
                    sess.run(fetches=[train_op,
                                      avg_train_loss,
                                      train_merge_summary_op,
                                      learning_rate])

                if math.isnan(train_loss_value):
                    raise ValueError('Train loss is nan')

                cost_time = time.time() - t_start
                train_cost_time_mean.append(cost_time)

                summary_writer.add_summary(summary=train_summary,
                                           global_step=epoch)

                # validation part
                t_start_val = time.time()

                val_loss_value, val_summary = \
                    sess.run(fetches=[avg_val_loss,
                                      val_merge_summary_op])

                summary_writer.add_summary(val_summary, global_step=epoch)

                cost_time_val = time.time() - t_start_val
                val_cost_time_mean.append(cost_time_val)

                if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                    logger.info('Epoch_Train: {:d} total_loss= {:6f} '
                                'lr= {:6f} mean_cost_time= {:5f}s '.
                                format(epoch,
                                       train_loss_value,
                                       lr,
                                       np.mean(train_cost_time_mean)
                                       ))
                    train_cost_time_mean.clear()

                if epoch % CFG.TRAIN.VAL_DISPLAY_STEP == 0:
                    logger.info('Epoch_Val: {:d} total_loss= {:6f} '
                                ' mean_cost_time= {:5f}s '.
                                format(epoch,
                                       val_loss_value,
                                       np.mean(val_cost_time_mean)))
                    val_cost_time_mean.clear()

                if epoch % 5000 == 0:
                    saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
    finally:
        summary_writer.close()
        sess.close()

    return
# Example #6 (0)
def train_shadownet(dataset_dir, weights_path, char_dict_path, ord_map_dict_path, need_decode=False):
    """
    Train ShadowNet (CRNN) on a single device.

    The validation net is built with ``reuse=True``, so it shares the
    ``shadow_net`` variables with the training net. When ``need_decode`` is
    set, every 500th step also decodes train and validation predictions and
    reports sequence distance and accuracy. If CFG.TRAIN.EARLY_STOPPING is
    enabled, training stops once the training CTC loss has not dropped by more
    than CFG.TRAIN.PATIENCE_DELTA for more than CFG.TRAIN.PATIENCE_EPOCHS steps.
    Summaries go to ``tboard/crnn_syn90k``. Checkpoints go to
    ``model/crnn_syn90k/shadownet.ckpt`` every 2000 steps.

    :param dataset_dir: directory containing the train/val records
    :param weights_path: checkpoint to resume from, or None to train from scratch
    :param char_dict_path: character dictionary for the feeders and decoder
    :param ord_map_dict_path: ord map dictionary for the feeders and decoder
    :param need_decode: if True, also decode predictions periodically and add
        the sequence-distance summaries
    :return: np.ndarray of the training CTC loss of every step run by this call
    """
    # prepare dataset
    train_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir,
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path,
        flags='train'
    )
    val_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir,
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path,
        flags='val'
    )
    train_images, train_labels, train_images_paths = train_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE
    )
    val_images, val_labels, val_images_paths = val_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE
    )

    # declare crnn net
    shadownet = crnn_net.ShadowNet(
        phase='train',
        hidden_nums=CFG.ARCH.HIDDEN_UNITS,
        layers_nums=CFG.ARCH.HIDDEN_LAYERS,
        num_classes=CFG.ARCH.NUM_CLASSES
    )
    shadownet_val = crnn_net.ShadowNet(
        phase='test',
        hidden_nums=CFG.ARCH.HIDDEN_UNITS,
        layers_nums=CFG.ARCH.HIDDEN_LAYERS,
        num_classes=CFG.ARCH.NUM_CLASSES
    )

    # set up decoder
    decoder = tf_io_pipline_fast_tools.CrnnFeatureReader(
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path
    )

    # set up training graph
    # NOTE(review): the graph is pinned to '/gpu:1'; the session below is created
    # with allow_soft_placement=True, which presumably lets TF fall back when
    # that device is absent -- confirm this device choice is intended.
    with tf.device('/gpu:1'):

        # compute loss and seq distance
        train_inference_ret, train_ctc_loss = shadownet.compute_loss(
            inputdata=train_images,
            labels=train_labels,
            name='shadow_net',
            reuse=False
        )
        val_inference_ret, val_ctc_loss = shadownet_val.compute_loss(
            inputdata=val_images,
            labels=val_labels,
            name='shadow_net',
            reuse=True
        )

        train_decoded, train_log_prob = tf.nn.ctc_beam_search_decoder(
            train_inference_ret,
            CFG.ARCH.SEQ_LENGTH * np.ones(CFG.TRAIN.BATCH_SIZE),
            merge_repeated=False
        )
        val_decoded, val_log_prob = tf.nn.ctc_beam_search_decoder(
            val_inference_ret,
            CFG.ARCH.SEQ_LENGTH * np.ones(CFG.TRAIN.BATCH_SIZE),
            merge_repeated=False
        )

        train_sequence_dist = tf.reduce_mean(
            tf.edit_distance(tf.cast(train_decoded[0], tf.int32), train_labels),
            name='train_edit_distance'
        )
        val_sequence_dist = tf.reduce_mean(
            tf.edit_distance(tf.cast(val_decoded[0], tf.int32), val_labels),
            name='val_edit_distance'
        )

        # set learning rate
        global_step = tf.Variable(0, name='global_step', trainable=False)
        learning_rate = tf.train.exponential_decay(
            learning_rate=CFG.TRAIN.LEARNING_RATE,
            global_step=global_step,
            decay_steps=CFG.TRAIN.LR_DECAY_STEPS,
            decay_rate=CFG.TRAIN.LR_DECAY_RATE,
            staircase=CFG.TRAIN.LR_STAIRCASE)

        # Run the UPDATE_OPS (e.g. batch-norm statistics) before every step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.train.MomentumOptimizer(
                learning_rate=learning_rate, momentum=0.9).minimize(
                loss=train_ctc_loss, global_step=global_step)

    # Set tf summary
    tboard_save_dir = 'tboard/crnn_syn90k'
    os.makedirs(tboard_save_dir, exist_ok=True)
    tf.summary.scalar(name='train_ctc_loss', tensor=train_ctc_loss)
    tf.summary.scalar(name='val_ctc_loss', tensor=val_ctc_loss)
    tf.summary.scalar(name='learning_rate', tensor=learning_rate)

    if need_decode:
        tf.summary.scalar(name='train_seq_distance', tensor=train_sequence_dist)
        tf.summary.scalar(name='val_seq_distance', tensor=val_sequence_dist)

    merge_summary_op = tf.summary.merge_all()

    # Set saver configuration
    saver = tf.train.Saver()
    model_save_dir = 'model/crnn_syn90k'
    os.makedirs(model_save_dir, exist_ok=True)
    model_name = 'shadownet.ckpt'
    model_save_path = ops.join(model_save_dir, model_name)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_dir)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    patience_counter = 1
    # Seeded with inf so the first real cost always counts as an improvement.
    cost_history = [np.inf]

    try:
        with sess.as_default():
            epoch = 0
            if weights_path is None:
                logger.info('Training from scratch')
                init = tf.global_variables_initializer()
                sess.run(init)
            else:
                logger.info('Restore model from {:s}'.format(weights_path))
                saver.restore(sess=sess, save_path=weights_path)
                # Resume counting from the restored step variable.
                epoch = sess.run(global_step)

            while epoch < train_epochs:
                # ``epoch`` is the 1-based number of the step run below.
                epoch += 1
                # Early stopping needs at least one recorded cost. Checking the
                # history length (not ``epoch > 1``) stays valid after a restore,
                # when ``epoch`` starts at the restored step but the history is
                # still empty apart from the inf sentinel.
                if len(cost_history) > 1 and CFG.TRAIN.EARLY_STOPPING:
                    # We always compare to the first point where cost didn't improve
                    if cost_history[-1 - patience_counter] - cost_history[-1] > CFG.TRAIN.PATIENCE_DELTA:
                        patience_counter = 1
                    else:
                        patience_counter += 1
                    if patience_counter > CFG.TRAIN.PATIENCE_EPOCHS:
                        logger.info("Cost didn't improve beyond {:f} for {:d} epochs, stopping early.".
                                    format(CFG.TRAIN.PATIENCE_DELTA, patience_counter))
                        break

                if need_decode and epoch % 500 == 0:
                    # train part
                    _, train_ctc_loss_value, train_seq_dist_value, \
                        train_predictions, train_labels_sparse, merge_summary_value = sess.run(
                            [train_op, train_ctc_loss, train_sequence_dist,
                             train_decoded, train_labels, merge_summary_op])

                    train_labels_str = decoder.sparse_tensor_to_str(train_labels_sparse)
                    train_predictions = decoder.sparse_tensor_to_str(train_predictions[0])
                    avg_train_accuracy = evaluation_tools.compute_accuracy(train_labels_str, train_predictions)

                    if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                        logger.info('Epoch_Train: {:d} cost= {:9f} seq distance= {:9f} train accuracy= {:9f}'.format(
                            epoch, train_ctc_loss_value, train_seq_dist_value, avg_train_accuracy))

                    # validation part
                    val_ctc_loss_value, val_seq_dist_value, \
                        val_predictions, val_labels_sparse = sess.run(
                            [val_ctc_loss, val_sequence_dist, val_decoded, val_labels])

                    val_labels_str = decoder.sparse_tensor_to_str(val_labels_sparse)
                    val_predictions = decoder.sparse_tensor_to_str(val_predictions[0])
                    avg_val_accuracy = evaluation_tools.compute_accuracy(val_labels_str, val_predictions)

                    if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                        logger.info('Epoch_Val: {:d} cost= {:9f} seq distance= {:9f} val accuracy= {:9f}'.format(
                            epoch, val_ctc_loss_value, val_seq_dist_value, avg_val_accuracy))
                else:
                    _, train_ctc_loss_value, merge_summary_value = sess.run(
                        [train_op, train_ctc_loss, merge_summary_op])

                    if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                        logger.info('Epoch_Train: {:d} cost= {:9f}'.format(epoch, train_ctc_loss_value))

                # record history train ctc loss
                cost_history.append(train_ctc_loss_value)
                # add training summary
                summary_writer.add_summary(summary=merge_summary_value, global_step=epoch)

                if epoch % 2000 == 0:
                    saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
    finally:
        summary_writer.close()
        sess.close()

    return np.array(cost_history[1:])  # Don't return the first np.inf
# Example #7 (0)
def evaluate_shadownet(dataset_dir,
                       weights_path,
                       char_dict_path,
                       ord_map_dict_path,
                       is_visualize=False,
                       is_process_all_data=False):
    """
    Evaluate a trained ShadowNet (CRNN) checkpoint on the 'test' split.

    Builds the test input pipeline and the inference graph, restores the
    checkpoint, decodes predictions with a CTC beam search (beam width 1),
    prints each prediction next to its ground-truth label and logs the mean
    per-batch accuracy.

    :param dataset_dir: dataset directory passed to
        ``shadownet_data_feed_pipline.CrnnDataFeeder``
    :param weights_path: checkpoint path passed to ``tf.train.Saver.restore``
    :param char_dict_path: character dictionary path used by the feeder and
        the label/prediction decoder
    :param ord_map_dict_path: ord-map dictionary path used by the feeder and
        the label/prediction decoder
    :param is_visualize: if True, display every evaluated image with matplotlib
    :param is_process_all_data: if True, evaluate
        ceil(test_sample_count / CFG.TEST.BATCH_SIZE) batches; otherwise
        evaluate a single batch
    :return: mean of the per-batch accuracies over the batches that were
        actually evaluated (0.0 if no batch could be evaluated)
    """
    # prepare dataset
    test_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir,
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path,
        flags='test')
    test_images, test_labels, test_images_paths = test_dataset.inputs(
        batch_size=CFG.TEST.BATCH_SIZE, num_epochs=1)

    # set up test sample count
    if is_process_all_data:
        log.info('Start computing test dataset sample counts')
        t_start = time.time()
        test_sample_count = test_dataset.sample_counts()
        log.info(
            'Computing test dataset sample counts finished, cost time: {:.5f}'.
            format(time.time() - t_start))
        num_iterations = int(math.ceil(test_sample_count /
                                       CFG.TEST.BATCH_SIZE))
    else:
        num_iterations = 1

    # declare crnn net
    shadownet = crnn_model.ShadowNet(phase='test',
                                     hidden_nums=CFG.ARCH.HIDDEN_UNITS,
                                     layers_nums=CFG.ARCH.HIDDEN_LAYERS,
                                     num_classes=CFG.ARCH.NUM_CLASSES)
    # set up decoder
    decoder = tf_io_pipline_tools.TextFeatureIO(
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path).reader

    # compute inference result
    test_inference_ret = shadownet.inference(inputdata=test_images,
                                             name='shadow_net',
                                             reuse=False)
    # The sequence-length vector is sized with the fixed test batch size.
    test_decoded, test_log_prob = tf.nn.ctc_beam_search_decoder(
        test_inference_ret,
        CFG.ARCH.SEQ_LENGTH * np.ones(CFG.TEST.BATCH_SIZE),
        beam_width=1,
        merge_repeated=False)

    # recover image from [-1.0, 1.0] ---> [0.0, 255.0]
    test_images = tf.multiply(tf.add(test_images, 1.0),
                              127.5,
                              name='recoverd_test_images')

    # Set saver configuration
    saver = tf.train.Saver()

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH

    accuracy_sum = 0.0
    evaluated_batches = 0

    # The context manager makes the session the default one and closes it
    # (releasing its resources) when evaluation is finished.
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess=sess, save_path=weights_path)

        log.info('Start predicting...')

        # Run exactly num_iterations batches. Previously the batch loop was
        # wrapped in `while True`, which kept consuming the whole split even
        # when only a single batch was requested.
        try:
            for _ in range(num_iterations):
                test_predictions_value, test_images_value, test_labels_value, \
                    test_images_paths_value = sess.run(
                        [test_decoded, test_images, test_labels, test_images_paths]
                    )
                test_images_paths_value = np.reshape(
                    test_images_paths_value,
                    newshape=test_images_paths_value.shape[0])
                test_images_paths_value = [
                    tmp.decode('utf-8') for tmp in test_images_paths_value
                ]
                test_images_names_value = [
                    ops.split(tmp)[1] for tmp in test_images_paths_value
                ]
                test_labels_value = decoder.sparse_tensor_to_str(
                    test_labels_value)
                test_predictions_value = decoder.sparse_tensor_to_str(
                    test_predictions_value[0])

                accuracy_sum += evaluation_tools.compute_accuracy(
                    test_labels_value,
                    test_predictions_value,
                    display=False)
                evaluated_batches += 1

                for index, test_image in enumerate(test_images_value):
                    print(
                        'Predict {:s} image with gt label: {:s} **** predicted label: {:s}'
                        .format(test_images_names_value[index],
                                test_labels_value[index],
                                test_predictions_value[index]))

                    # avoid accidentally displaying for the whole dataset
                    if is_visualize:
                        plt.imshow(
                            np.array(test_image, np.uint8)[:, :,
                                                           (2, 1, 0)])
                        plt.show()
        except tf.errors.OutOfRangeError:
            print('End of tfrecords sequence')
        except Exception as err:
            # Best-effort evaluation: report the error and the accuracy over
            # whatever batches were evaluated before it happened.
            print(err)

    if evaluated_batches == 0:
        log.info('No test batch was evaluated')
        return 0.0

    # Mean of per-batch means: it equals the true per-sample mean only when
    # every batch holds CFG.TEST.BATCH_SIZE samples. Dividing by the number
    # of batches actually evaluated (not num_iterations) keeps it correct
    # when the input ends or fails early.
    mean_accuracy = accuracy_sum / evaluated_batches
    log.info('Mean test accuracy is {:5f}'.format(mean_accuracy))
    return mean_accuracy