Example #1
def compute_accuracy(convert, ground_truth_labels, prediction_labels):
    """
    Decode the sparse label tensors to strings and compute full-sequence accuracy.

    :param convert: feature reader providing sparse_tensor_to_str
    :param ground_truth_labels: sparse tensor of ground-truth label indices
    :param prediction_labels: sparse tensor of decoded prediction indices
    :return: full-sequence accuracy
    """
    ground_truth_str = convert.sparse_tensor_to_str(ground_truth_labels)
    predictions_str = convert.sparse_tensor_to_str(prediction_labels)
    accuracy = evaluation_tools.compute_accuracy(ground_truth_str,
                                                 predictions_str,
                                                 mode='full_sequence')

    return accuracy
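The evaluation_tools helper is not shown in these examples. As a rough illustration of what the 'full_sequence' mode measures, here is a minimal sketch (a hypothetical helper, not the project's actual implementation): a prediction only counts as correct when the whole decoded string matches its ground-truth string.

# Minimal sketch of 'full_sequence' accuracy (hypothetical helper; the real
# evaluation_tools.compute_accuracy may handle edge cases differently).
def full_sequence_accuracy(ground_truth_str, predictions_str):
    """Fraction of predictions that match their ground-truth string exactly."""
    assert len(ground_truth_str) == len(predictions_str)
    matches = sum(1 for gt, pred in zip(ground_truth_str, predictions_str) if gt == pred)
    return matches / len(ground_truth_str)

# Only one of the three decoded sequences matches its label exactly.
print(full_sequence_accuracy(['hello', 'world', 'crnn'],
                             ['hello', 'w0rld', 'cnn']))  # 0.333...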
Example #2
def train_shadownet(dataset_dir, weights_path, char_dict_path, ord_map_dict_path, need_decode=False):
    """
    Train the shadownet CRNN model.

    :param dataset_dir: directory holding the train/val tfrecords dataset
    :param weights_path: checkpoint to restore from, or None to train from scratch
    :param char_dict_path: path to the character dictionary
    :param ord_map_dict_path: path to the ord map dictionary
    :param need_decode: if True, periodically decode predictions and log accuracy and sequence distance
    :return: array of the recorded training ctc loss values
    """
    # prepare dataset
    train_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir,
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path,
        flags='train'
    )
    val_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir,
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path,
        flags='val'
    )
    train_images, train_labels, train_images_paths = train_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE
    )
    val_images, val_labels, val_images_paths = val_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE
    )

    # declare crnn net
    shadownet = crnn_net.ShadowNet(
        phase='train',
        hidden_nums=CFG.ARCH.HIDDEN_UNITS,
        layers_nums=CFG.ARCH.HIDDEN_LAYERS,
        num_classes=CFG.ARCH.NUM_CLASSES
    )
    shadownet_val = crnn_net.ShadowNet(
        phase='test',
        hidden_nums=CFG.ARCH.HIDDEN_UNITS,
        layers_nums=CFG.ARCH.HIDDEN_LAYERS,
        num_classes=CFG.ARCH.NUM_CLASSES
    )

    # set up decoder
    decoder = tf_io_pipline_fast_tools.CrnnFeatureReader(
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path
    )

    # set up training graph
    with tf.device('/gpu:1'):

        # compute loss and seq distance
        train_inference_ret, train_ctc_loss = shadownet.compute_loss(
            inputdata=train_images,
            labels=train_labels,
            name='shadow_net',
            reuse=False
        )
        val_inference_ret, val_ctc_loss = shadownet_val.compute_loss(
            inputdata=val_images,
            labels=val_labels,
            name='shadow_net',
            reuse=True
        )

        train_decoded, train_log_prob = tf.nn.ctc_beam_search_decoder(
            train_inference_ret,
            CFG.ARCH.SEQ_LENGTH * np.ones(CFG.TRAIN.BATCH_SIZE),
            merge_repeated=False
        )
        val_decoded, val_log_prob = tf.nn.ctc_beam_search_decoder(
            val_inference_ret,
            CFG.ARCH.SEQ_LENGTH * np.ones(CFG.TRAIN.BATCH_SIZE),
            merge_repeated=False
        )

        train_sequence_dist = tf.reduce_mean(
            tf.edit_distance(tf.cast(train_decoded[0], tf.int32), train_labels),
            name='train_edit_distance'
        )
        val_sequence_dist = tf.reduce_mean(
            tf.edit_distance(tf.cast(val_decoded[0], tf.int32), val_labels),
            name='val_edit_distance'
        )

        # set learning rate
        global_step = tf.Variable(0, name='global_step', trainable=False)
        learning_rate = tf.train.exponential_decay(
            learning_rate=CFG.TRAIN.LEARNING_RATE,
            global_step=global_step,
            decay_steps=CFG.TRAIN.LR_DECAY_STEPS,
            decay_rate=CFG.TRAIN.LR_DECAY_RATE,
            staircase=CFG.TRAIN.LR_STAIRCASE)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate, momentum=0.9).minimize(
                loss=train_ctc_loss, global_step=global_step)

    # Set tf summary
    tboard_save_dir = 'tboard/crnn_syn90k'
    os.makedirs(tboard_save_dir, exist_ok=True)
    tf.summary.scalar(name='train_ctc_loss', tensor=train_ctc_loss)
    tf.summary.scalar(name='val_ctc_loss', tensor=val_ctc_loss)
    tf.summary.scalar(name='learning_rate', tensor=learning_rate)

    if need_decode:
        tf.summary.scalar(name='train_seq_distance', tensor=train_sequence_dist)
        tf.summary.scalar(name='val_seq_distance', tensor=val_sequence_dist)

    merge_summary_op = tf.summary.merge_all()

    # Set saver configuration
    saver = tf.train.Saver()
    model_save_dir = 'model/crnn_syn90k'
    os.makedirs(model_save_dir, exist_ok=True)
    #train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    #model_name = 'shadownet_{:s}.ckpt'.format(str(train_start_time))
    model_name = 'shadownet.ckpt'
    model_save_path = ops.join(model_save_dir, model_name)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH

    sess = tf.Session(config=sess_config)

    summary_writer = tf.summary.FileWriter(tboard_save_dir)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    with sess.as_default():
        epoch = 0
        if weights_path is None:
            logger.info('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            logger.info('Restore model from {:s}'.format(weights_path))
            saver.restore(sess=sess, save_path=weights_path)
            epoch = sess.run(tf.train.get_global_step())

        patience_counter = 1
        cost_history = [np.inf]
        while epoch < train_epochs:
            epoch += 1
            # setup early stopping
            if epoch > 1 and CFG.TRAIN.EARLY_STOPPING:
                # We always compare to the first point where cost didn't improve
                if cost_history[-1 - patience_counter] - cost_history[-1] > CFG.TRAIN.PATIENCE_DELTA:
                    patience_counter = 1
                else:
                    patience_counter += 1
                if patience_counter > CFG.TRAIN.PATIENCE_EPOCHS:
                    logger.info("Cost didn't improve beyond {:f} for {:d} epochs, stopping early.".
                                format(CFG.TRAIN.PATIENCE_DELTA, patience_counter))
                    break

            if need_decode and epoch % 500 == 0:
                # train part
                _, train_ctc_loss_value, train_seq_dist_value, \
                    train_predictions, train_labels_sparse, merge_summary_value = sess.run(
                     [optimizer, train_ctc_loss, train_sequence_dist,
                      train_decoded, train_labels, merge_summary_op])

                train_labels_str = decoder.sparse_tensor_to_str(train_labels_sparse)
                train_predictions = decoder.sparse_tensor_to_str(train_predictions[0])
                avg_train_accuracy = evaluation_tools.compute_accuracy(train_labels_str, train_predictions)

                if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                    logger.info('Epoch_Train: {:d} cost= {:9f} seq distance= {:9f} train accuracy= {:9f}'.format(
                        epoch + 1, train_ctc_loss_value, train_seq_dist_value, avg_train_accuracy))

                # validation part
                val_ctc_loss_value, val_seq_dist_value, \
                    val_predictions, val_labels_sparse = sess.run(
                     [val_ctc_loss, val_sequence_dist, val_decoded, val_labels])

                val_labels_str = decoder.sparse_tensor_to_str(val_labels_sparse)
                val_predictions = decoder.sparse_tensor_to_str(val_predictions[0])
                avg_val_accuracy = evaluation_tools.compute_accuracy(val_labels_str, val_predictions)

                if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                    logger.info('Epoch_Val: {:d} cost= {:9f} seq distance= {:9f} val accuracy= {:9f}'.format(
                        epoch + 1, val_ctc_loss_value, val_seq_dist_value, avg_val_accuracy))
            else:
                _, train_ctc_loss_value, merge_summary_value = sess.run(
                    [optimizer, train_ctc_loss, merge_summary_op])

                if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                    logger.info('Epoch_Train: {:d} cost= {:9f}'.format(epoch + 1, train_ctc_loss_value))

            # record history train ctc loss
            cost_history.append(train_ctc_loss_value)
            # add training summary
            summary_writer.add_summary(summary=merge_summary_value, global_step=epoch)

            if epoch % 2000 == 0:
                saver.save(sess=sess, save_path=model_save_path, global_step=epoch)

    return np.array(cost_history[1:])  # Don't return the first np.inf
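The learning-rate schedule above comes from tf.train.exponential_decay, which scales the base rate by decay_rate raised to global_step / decay_steps (with integer division when staircase is set). A small sketch with illustrative numbers, since the actual CFG.TRAIN values are defined elsewhere in the project:

# Sketch of the schedule produced by tf.train.exponential_decay above.
# The numbers below are placeholders, not the project's CFG.TRAIN values.
def exponential_decay(base_lr, global_step, decay_steps, decay_rate, staircase=True):
    """decayed_lr = base_lr * decay_rate ** (global_step / decay_steps)."""
    exponent = global_step // decay_steps if staircase else global_step / decay_steps
    return base_lr * decay_rate ** exponent

# With staircase decay the rate stays flat, then drops from 0.01 to 0.001
# at step 500000 and to 0.0001 at step 1000000.
for step in (0, 499999, 500000, 1000000):
    print(step, exponential_decay(0.01, step, decay_steps=500000, decay_rate=0.1))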
Example #3
def evaluate_shadownet(dataset_dir,
                       weights_path,
                       char_dict_path,
                       ord_map_dict_path,
                       is_visualize=False,
                       is_process_all_data=False):
    """

    :param dataset_dir:
    :param weights_path:
    :param char_dict_path:
    :param ord_map_dict_path:
    :param is_visualize:
    :param is_process_all_data:
    :return:
    """
    # prepare dataset
    test_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir,
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path,
        flags='test')
    test_images, test_labels, test_images_paths = test_dataset.inputs(
        batch_size=CFG.TEST.BATCH_SIZE)

    # set up test sample count
    if is_process_all_data:
        log.info('Start computing test dataset sample counts')
        t_start = time.time()
        test_sample_count = test_dataset.sample_counts()
        log.info(
            'Computing test dataset sample counts finished, cost time: {:.5f}'.
            format(time.time() - t_start))
        num_iterations = int(math.ceil(test_sample_count /
                                       CFG.TEST.BATCH_SIZE))
    else:
        num_iterations = 1

    # declare crnn net
    shadownet = crnn_net.ShadowNet(phase='test',
                                   hidden_nums=CFG.ARCH.HIDDEN_UNITS,
                                   layers_nums=CFG.ARCH.HIDDEN_LAYERS,
                                   num_classes=CFG.ARCH.NUM_CLASSES)
    # set up decoder
    decoder = tf_io_pipline_fast_tools.CrnnFeatureReader(
        char_dict_path=char_dict_path, ord_map_dict_path=ord_map_dict_path)

    # compute inference result
    test_inference_ret = shadownet.inference(inputdata=test_images,
                                             name='shadow_net',
                                             reuse=False)
    test_decoded, test_log_prob = tf.nn.ctc_beam_search_decoder(
        test_inference_ret,
        CFG.ARCH.SEQ_LENGTH * np.ones(CFG.TEST.BATCH_SIZE),
        beam_width=1,
        merge_repeated=False)

    # recover image from [-1.0, 1.0] ---> [0.0, 255.0]
    test_images = tf.multiply(tf.add(test_images, 1.0),
                              127.5,
                              name='recoverd_test_images')

    # Set saver configuration
    saver = tf.train.Saver()

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH

    sess = tf.Session(config=sess_config)

    with sess.as_default():
        saver.restore(sess=sess, save_path=weights_path)

        log.info('Start predicting...')

        per_char_accuracy = 0.0
        full_sequence_accuracy = 0.0

        total_labels_char_list = []
        total_predictions_char_list = []

        while True:
            try:

                for epoch in range(num_iterations):
                    test_predictions_value, test_images_value, test_labels_value, \
                     test_images_paths_value = sess.run(
                        [test_decoded, test_images, test_labels, test_images_paths])
                    test_images_paths_value = np.reshape(
                        test_images_paths_value,
                        newshape=test_images_paths_value.shape[0])
                    test_images_paths_value = [
                        tmp.decode('utf-8') for tmp in test_images_paths_value
                    ]
                    test_images_names_value = [
                        ops.split(tmp)[1] for tmp in test_images_paths_value
                    ]
                    test_labels_value = decoder.sparse_tensor_to_str(
                        test_labels_value)
                    test_predictions_value = decoder.sparse_tensor_to_str(
                        test_predictions_value[0])

                    per_char_accuracy += evaluation_tools.compute_accuracy(
                        test_labels_value,
                        test_predictions_value,
                        display=False,
                        mode='per_char')
                    full_sequence_accuracy += evaluation_tools.compute_accuracy(
                        test_labels_value,
                        test_predictions_value,
                        display=False,
                        mode='full_sequence')

                    for index, test_image in enumerate(test_images_value):
                        log.info(
                            'Predict {:s} image with gt label: {:s} **** predicted label: {:s}'
                            .format(test_images_names_value[index],
                                    test_labels_value[index],
                                    test_predictions_value[index]))

                        if is_visualize:
                            plt.imshow(np.array(test_image, np.uint8)[:, :, (2, 1, 0)])
                            plt.show()

                        test_labels_char_list_value = [
                            s for s in test_labels_value[index]
                        ]
                        test_predictions_char_list_value = [
                            s for s in test_predictions_value[index]
                        ]

                        if not test_labels_char_list_value or not test_predictions_char_list_value:
                            continue

                        if len(test_labels_char_list_value) != len(test_predictions_char_list_value):
                            min_length = min(len(test_labels_char_list_value),
                                             len(test_predictions_char_list_value))
                            test_labels_char_list_value = test_labels_char_list_value[:min_length - 1]
                            test_predictions_char_list_value = test_predictions_char_list_value[:min_length - 1]

                        assert len(test_labels_char_list_value) == len(test_predictions_char_list_value), \
                            log.error('{}, {}'.format(test_labels_char_list_value, test_predictions_char_list_value))

                        total_labels_char_list.extend(
                            test_labels_char_list_value)
                        total_predictions_char_list.extend(
                            test_predictions_char_list_value)

                        if is_visualize:
                            plt.imshow(np.array(test_image, np.uint8)[:, :, (2, 1, 0)])

            except tf.errors.OutOfRangeError:
                log.error('End of tfrecords sequence')
                break
            except Exception as err:
                log.error(err)
                break

        avg_per_char_accuracy = per_char_accuracy / num_iterations
        avg_full_sequence_accuracy = full_sequence_accuracy / num_iterations
        log.info('Mean test per char accuracy is {:5f}'.format(
            avg_per_char_accuracy))
        log.info('Mean test full sequence accuracy is {:5f}'.format(
            avg_full_sequence_accuracy))

        # compute confusion matrix
        cnf_matrix = confusion_matrix(total_labels_char_list,
                                      total_predictions_char_list)
        np.set_printoptions(precision=2)
        evaluation_tools.plot_confusion_matrix(cm=cnf_matrix, normalize=True)

        plt.show()
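The confusion matrix above is built from per-position character pairs accumulated over all test samples. A minimal sketch of that step, assuming labels and predictions are plain Python strings as produced by sparse_tensor_to_str (here strings are simply truncated to the shorter of the two, whereas the code above also drops the final character):

# Minimal sketch of the per-character confusion-matrix step (illustrative
# strings; the real values come from decoder.sparse_tensor_to_str).
from sklearn.metrics import confusion_matrix

total_labels_chars = []
total_predictions_chars = []
for gt, pred in [('state', 'state'), ('farm', 'form')]:
    length = min(len(gt), len(pred))   # align per-position character pairs
    total_labels_chars.extend(gt[:length])
    total_predictions_chars.extend(pred[:length])

# rows = ground-truth characters, columns = predicted characters
print(confusion_matrix(total_labels_chars, total_predictions_chars))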
Example #4
def evaluate_shadownet(dataset_dir,
                       weights_path,
                       char_dict_path,
                       ord_map_dict_path,
                       is_visualize=False,
                       is_process_all_data=False):
    """

    :param dataset_dir:
    :param weights_path:
    :param char_dict_path:
    :param ord_map_dict_path:
    :param is_visualize:
    :param is_process_all_data:
    :return:
    """
    # prepare dataset
    test_dataset = shadownet_data_feed_pipline.CrnnDataFeeder(
        dataset_dir=dataset_dir,
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path,
        flags='test')
    test_images, test_labels, test_images_paths = test_dataset.inputs(
        batch_size=CFG.TEST.BATCH_SIZE, num_epochs=1)

    # set up test sample count
    if is_process_all_data:
        log.info('Start computing test dataset sample counts')
        t_start = time.time()
        test_sample_count = test_dataset.sample_counts()
        log.info(
            'Computing test dataset sample counts finished, cost time: {:.5f}'.
            format(time.time() - t_start))
        num_iterations = int(math.ceil(test_sample_count /
                                       CFG.TEST.BATCH_SIZE))
    else:
        num_iterations = 1

    # declare crnn net
    shadownet = crnn_model.ShadowNet(phase='test',
                                     hidden_nums=CFG.ARCH.HIDDEN_UNITS,
                                     layers_nums=CFG.ARCH.HIDDEN_LAYERS,
                                     num_classes=CFG.ARCH.NUM_CLASSES)
    # set up decoder
    decoder = tf_io_pipline_tools.TextFeatureIO(
        char_dict_path=char_dict_path,
        ord_map_dict_path=ord_map_dict_path).reader

    # compute inference result
    test_inference_ret = shadownet.inference(inputdata=test_images,
                                             name='shadow_net',
                                             reuse=False)
    test_decoded, test_log_prob = tf.nn.ctc_beam_search_decoder(
        test_inference_ret,
        CFG.ARCH.SEQ_LENGTH * np.ones(CFG.TEST.BATCH_SIZE),
        beam_width=1,
        merge_repeated=False)

    # recover image from [-1.0, 1.0] ---> [0.0, 255.0]
    test_images = tf.multiply(tf.add(test_images, 1.0),
                              127.5,
                              name='recoverd_test_images')

    # Set saver configuration
    saver = tf.train.Saver()

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH

    sess = tf.Session(config=sess_config)

    with sess.as_default():
        saver.restore(sess=sess, save_path=weights_path)

        log.info('Start predicting...')

        accuracy = 0

        while True:
            try:

                for epoch in range(num_iterations):
                    test_predictions_value, test_images_value, test_labels_value, \
                    test_images_paths_value = sess.run(
                        [test_decoded, test_images, test_labels, test_images_paths]
                    )
                    test_images_paths_value = np.reshape(
                        test_images_paths_value,
                        newshape=test_images_paths_value.shape[0])
                    test_images_paths_value = [
                        tmp.decode('utf-8') for tmp in test_images_paths_value
                    ]
                    test_images_names_value = [
                        ops.split(tmp)[1] for tmp in test_images_paths_value
                    ]
                    test_labels_value = decoder.sparse_tensor_to_str(
                        test_labels_value)
                    test_predictions_value = decoder.sparse_tensor_to_str(
                        test_predictions_value[0])

                    accuracy += evaluation_tools.compute_accuracy(
                        test_labels_value,
                        test_predictions_value,
                        display=False)

                    for index, test_image in enumerate(test_images_value):
                        print(
                            'Predict {:s} image with gt label: {:s} **** predicted label: {:s}'
                            .format(test_images_names_value[index],
                                    test_labels_value[index],
                                    test_predictions_value[index]))

                        # avoid accidentally displaying for the whole dataset
                        if is_visualize:
                            plt.imshow(np.array(test_image, np.uint8)[:, :, (2, 1, 0)])
                            plt.show()
            except tf.errors.OutOfRangeError:
                print('End of tfrecords sequence')
                break
            except Exception as err:
                print(err)
                break

        # we compute a mean of means, so we need the sample sizes to be constant
        # (BATCH_SIZE) for this to equal the actual mean
        accuracy /= num_iterations
        log.info('Mean test accuracy is {:5f}'.format(accuracy))
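As the comment above notes, averaging per-batch accuracies only equals the true per-sample accuracy when every batch has the same size. A small sketch with illustrative numbers showing the difference, along with a sample-weighted alternative:

# Illustrative numbers only: three batches, the last one smaller than the rest.
batch_accuracies = [0.90, 0.80, 0.50]
batch_sizes = [32, 32, 8]

mean_of_means = sum(batch_accuracies) / len(batch_accuracies)
weighted_mean = (sum(acc * n for acc, n in zip(batch_accuracies, batch_sizes))
                 / sum(batch_sizes))

print(mean_of_means)   # ~0.733, treats the small batch like a full one
print(weighted_mean)   # ~0.811, the actual per-sample accuracy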