def main(_):
    with tf.Graph().as_default():
        out_shape=[FLAGS.train_image_size] * 2

        image_input=tf.placeholder(tf.uint8, shape=(None, None, 3))
        shape_input=tf.placeholder(tf.int32, shape=(2,))

        features, output_shape=\
            textboxes_plusplus_preprocessing.preprocess_for_eval(
                image_input,
                out_shape,
                data_format=FLAGS.data_format,
                output_rgb=False)
        features=tf.expand_dims(features, axis=0) # (1, ?, ?, 3)
        output_shape=tf.expand_dims(output_shape, axis=0) # (1, 2)

        with tf.variable_scope(FLAGS.model_scope,
                               default_name=None,
                               values=[features],
                               reuse=tf.AUTO_REUSE):
            with tf.device('/cpu:0'):
                anchor_processor=\
                    anchor_manipulator.AnchorProcessor(
                        positive_threshold=None,
                        ignore_threshold=None,
                        prior_scaling=config.PRIOR_SCALING)

                anchor_heights_all_layers,\
                anchor_widths_all_layers,\
                num_anchors_per_location_all_layers=\
                    anchor_processor.get_anchors_size_all_layers(
                        config.ALL_ANCHOR_SCALES,
                        config.ALL_EXTRA_SCALES,
                        config.ALL_ANCHOR_RATIOS,
                        config.NUM_FEATURE_LAYERS)

                # shape=(num_anchors_all_layers,).
                anchors_ymin,\
                anchors_xmin,\
                anchors_ymax,\
                anchors_xmax,\
                _=\
                    anchor_processor.get_all_anchors_all_layers(
                        tf.squeeze(output_shape, axis=0),
                        anchor_heights_all_layers,
                        anchor_widths_all_layers,
                        num_anchors_per_location_all_layers,
                        config.ANCHOR_OFFSETS,
                        config.VERTICAL_OFFSETS,
                        config.ALL_LAYER_SHAPES,
                        config.ALL_LAYER_STRIDES,
                        [0.] * config.NUM_FEATURE_LAYERS,
                        [False] * config.NUM_FEATURE_LAYERS)

                backbone=textboxes_plusplus_net.VGG16Backbone(FLAGS.data_format)
                feature_layers=backbone.forward(features, training=False)
                # shape=(num_features,
                #        bs,
                #        fh,
                #        fw,
                #        num_anchors_per_location * 2 * num_offsets)
                location_predictions, class_predictions=\
                    textboxes_plusplus_net.multibox_head(
                        feature_layers,
                        FLAGS.num_classes,
                        config.NUM_OFFSETS,
                        num_anchors_per_location_all_layers,
                        data_format=FLAGS.data_format)
                if FLAGS.data_format == 'channels_first':
                    class_predictions=\
                        [tf.transpose(pred,
                                      [0, 2, 3, 1])\
                         for pred in class_predictions]
                    location_predictions=\
                        [tf.transpose(pred,
                                      [0, 2, 3, 1])\
                         for pred in location_predictions]
                class_predictions=\
                    [tf.reshape(pred,
                                [-1, FLAGS.num_classes])\
                     for pred in class_predictions]
                location_predictions=\
                    [tf.reshape(pred, [-1, config.NUM_OFFSETS])\
                     for pred in location_predictions]

                class_predictions=tf.concat(class_predictions, axis=0)
                location_predictions=tf.concat(location_predictions, axis=0)
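                # With batch size 1, after the concat:
                # class_predictions: (num_anchors_all_layers, num_classes)
                # location_predictions: (num_anchors_all_layers, NUM_OFFSETS)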

                # total_parameters = 0
                # for variable in tf.trainable_variables():
                #     # shape is an array of tf.Dimension
                #     shape = variable.get_shape()
                #     print(shape)
                #     print(len(shape))
                #     variable_parameters = 1
                #     for dim in shape:
                #         print(dim)
                #         variable_parameters *= dim.value
                #     print(variable_parameters)
                #     total_parameters += variable_parameters
                # print(total_parameters)


        with tf.device('/cpu:0'):
            bboxes_pred, quadrilaterals_pred=\
                anchor_processor.decode_anchors(
                    location_predictions,
                    anchors_ymin,
                    anchors_xmin,
                    anchors_ymax,
                    anchors_xmax)
            selected_bboxes,\
            selected_quadrilaterals,\
            selected_scores=\
                bbox_util.parse_by_class(
                    tf.squeeze(output_shape, axis=0),
                    class_predictions,
                    bboxes_pred,
                    quadrilaterals_pred,
                    FLAGS.num_classes,
                    FLAGS.select_threshold,
                    FLAGS.min_size,
                    FLAGS.keep_topk,
                    FLAGS.nms_topk,
                    FLAGS.nms_threshold)

            labels_list=[]
            scores_list=[]
            bboxes_list=[]
            quadrilaterals_list=[]
            for k, v in selected_scores.items():
                labels_list.append(tf.ones_like(v, tf.int32) * k)
                scores_list.append(v)
                bboxes_list.append(selected_bboxes[k])
                quadrilaterals_list.append(selected_quadrilaterals[k])
            all_labels=tf.concat(labels_list, axis=0)
            all_scores=tf.concat(scores_list, axis=0)
            all_bboxes=tf.concat(bboxes_list, axis=0)
            all_quadrilaterals=tf.concat(quadrilaterals_list, axis=0)

        saver=tf.train.Saver()
        with tf.Session() as sess:
            init=tf.global_variables_initializer()
            sess.run(init)

            saver.restore(sess, get_checkpoint())

            total_time=0
            # np_image=imread('./demo/' + FLAGS.image_file_name)
            image_file_names=sorted(os.listdir(FLAGS.source_directory))
            for i, image_file_name in enumerate(image_file_names):
                np_image=imread(os.path.join(FLAGS.source_directory, image_file_name))
                start_time=time.time()

                labels_,\
                scores_,\
                bboxes_,\
                quadrilaterals_,\
                output_shape_=\
                    sess.run([all_labels,
                              all_scores,
                              all_bboxes,
                              all_quadrilaterals,
                              output_shape],
                             feed_dict={image_input : np_image,
                                        shape_input : np_image.shape[:-1]})

                elapsed_time=time.time() - start_time
                print('{}: elapsed_time = {}'.format(i + 1, elapsed_time))
                total_time+=elapsed_time

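                # Predictions are in the resized network-input frame
                # (output_shape_); scale them back to the original image
                # resolution.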
                scale_y = np_image.shape[0] / output_shape_[0, 0]
                scale_x = np_image.shape[1] / output_shape_[0, 1]
                bboxes_[:, 0::2] *= scale_y  # ymin, ymax
                bboxes_[:, 1::2] *= scale_x  # xmin, xmax
                quadrilaterals_[:, 0::2] *= scale_y  # y1, y2, y3, y4
                quadrilaterals_[:, 1::2] *= scale_x  # x1, x2, x3, x4

                # image_with_bboxes=\
                #     drawing_toolbox.draw_bboxes_on_image(
                #         np_image.copy(),
                #         labels_,
                #         scores_,
                #         bboxes_,
                #         thickness=2)
                # imsave('./demo/' + FLAGS.image_file_name[:-4] + '_bboxes' + '.jpg',
                #        image_with_bboxes)
                image_with_quadrilaterals=\
                    drawing_toolbox.draw_quadrilaterals_on_image(
                        np_image.copy(),
                        labels_,
                        scores_,
                        quadrilaterals_,
                        thickness=2)
                imsave(os.path.join(FLAGS.storage_directory,
                                    image_file_name[:-4] +
                                    '_quadrilaterals.jpg'),
                       image_with_quadrilaterals)
                
                y1, x1, y2, x2,\
                y3, x3, y4, x4=[int(e) for e in quadrilaterals_[0, :]]

                top_left_vertex = [x1, y1]
                top_right_vertex = [x2, y2]
                bottom_left_vertex = [x4, y4]
                bottom_right_vertex = [x3, y3]

                ymin=int(round(bboxes_[0, 0]))
                xmin=int(round(bboxes_[0, 1]))
                ymax=int(round(bboxes_[0, 2]))
                xmax=int(round(bboxes_[0, 3]))

                plate_width = xmax - xmin
                plate_height = ymax - ymin

                pts1 = np.float32([top_left_vertex, top_right_vertex,
                                   bottom_left_vertex, bottom_right_vertex])
                pts2 = np.float32([[0, 0], [plate_width, 0],
                                   [0, plate_height],
                                   [plate_width, plate_height]])

                M = cv2.getPerspectiveTransform(pts1, pts2)
                cropped_image = cv2.warpPerspective(
                    np_image.copy(), M, (plate_width, plate_height))
                imsave(os.path.join(FLAGS.storage_directory,
                                    image_file_name[:-4] + '_cropped.jpg'),
                       cropped_image)
            
            print('total_time: ', total_time)
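# A hedged, self-contained sketch (not from the original script) of the
# perspective-rectification step used in the loop above: the four predicted
# quadrilateral vertices are mapped onto an axis-aligned rectangle via
# cv2.getPerspectiveTransform / cv2.warpPerspective. All coordinates below
# are hypothetical stand-ins for model output.
def rectify_quadrilateral_demo():
    import cv2
    import numpy as np
    image = np.zeros((400, 600, 3), dtype=np.uint8)  # placeholder image
    # Vertices in (x, y) order, as cv2 expects: top-left, top-right,
    # bottom-left, bottom-right.
    pts1 = np.float32([[120, 80], [420, 95], [110, 200], [430, 215]])
    width, height = 320, 120  # target rectangle size
    pts2 = np.float32([[0, 0], [width, 0], [0, height], [width, height]])
    M = cv2.getPerspectiveTransform(pts1, pts2)
    return cv2.warpPerspective(image, M, (width, height))  # (120, 320, 3)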
def main(_):
    with tf.Graph().as_default():

        def split_image_into_overlapped_images(image, n, r):
            """TODO: Docstring for split_image_into_overlapped_images.

            :image: TODO
            :n: TODO
            :r: TODO
            :returns: TODO

            """
            IH, IW = tf.shape(image)[0], tf.shape(image)[1]
            ny, nx = n
            ry, rx = r
            SH = tf.cast(
                tf.floordiv(tf.cast(IH, tf.float32), (ny - ny * ry + ry)),
                tf.int32)
            SW = tf.cast(
                tf.floordiv(tf.cast(IW, tf.float32), (nx - nx * rx + rx)),
                tf.int32)
            OH = tf.cast(ry * tf.cast(SH, tf.float32), tf.int32)
            OW = tf.cast(rx * tf.cast(SW, tf.float32), tf.int32)
            images = []
            offsets = []
            for i in range(ny):
                oy = i * (SH - OH)
                for j in range(nx):
                    ox = j * (SW - OW)
                    offsets.append([oy, ox])
                    images.append(image[oy:oy + SH, ox:ox + SW])
            return [[image, tf.shape(image), o]
                    for image, o in zip(images, offsets)]

        output_shape = [FLAGS.image_size] * 2

        input_image = tf.placeholder(tf.uint8, shape=(None, None, 3))
        # nr1 = [(2, 0.7), (4, 0.6), (8, 0.5)]
        # nr2 = [(4, 0.4)]  # no1
        # nr3 = [(4, 0.2)]
        # nr4 = [(4, 0.3)]
        # nr5 = [(4, 0.6)]
        # nr6 = [(4, 0.5)]
        # nr7 = [(8, 0.2)]
        # nr8 = [(8, 0.8)]
        # nr9 = [(8, 0.4)]  # no1
        # nr10 = [(2, 0.8)]
        # nr11 = [(2, 0.2)]
        # nr12 = [(2, 0.4)]
        # nr13 = [(2, 0.6)]  # no1
        # nr14 = [(2, 0.5)]
        # nr15 = [(2, 0.6), (4, 0.4)]  # select_threshold = 0.5
        nr16 = [(2, 0.6), (4, 0.4)]  # select_threshold = 0.95
        images, shapes, offsets =\
            zip(*([[image, shape, o]
                   for n, r in nr16
                   for image, shape, o in split_image_into_overlapped_images(
                       input_image,
                       (n, n),
                       (r, r))] + [[input_image,
                                    tf.shape(input_image), [0, 0]]]))
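        # images/shapes/offsets now hold one entry per overlapped tile, plus
        # a final entry for the full input image itself (offset [0, 0]).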
        # images = [images[0], images[1]]
        # shapes = [shapes[0], shapes[1]]
        # offsets = [offsets[0], offsets[1]]

        oys, oxs = zip(*offsets)
        shapes = tf.stack(shapes)
        oys = tf.stack(oys)
        oxs = tf.stack(oxs)
        oys = tf.expand_dims(oys, -1)
        oxs = tf.expand_dims(oxs, -1)

        features = []
        for image in images:
            features.append(
                textboxes_plusplus_preprocessing.preprocess_for_eval(
                    image,
                    None,
                    None,
                    output_shape,
                    data_format=FLAGS.data_format,
                    output_rgb=False))
        features = tf.stack(features, axis=0)
        output_shape =\
            tf.expand_dims(
                tf.constant(output_shape,
                            dtype=tf.int32),
                axis=0)  # (1, 2)

        with tf.variable_scope(FLAGS.model_scope,
                               default_name=None,
                               values=[features],
                               reuse=tf.AUTO_REUSE):
            with tf.device('/cpu:0'):
                anchor_processor =\
                    anchor_manipulator.AnchorProcessor(
                        positive_threshold=None,
                        ignore_threshold=None,
                        prior_scaling=config.PRIOR_SCALING)

                anchor_heights_all_layers,\
                    anchor_widths_all_layers,\
                    num_anchors_per_location_all_layers =\
                    anchor_processor.get_anchors_size_all_layers(
                        config.ALL_ANCHOR_SCALES,
                        config.ALL_EXTRA_SCALES,
                        config.ALL_ANCHOR_RATIOS,
                        config.NUM_FEATURE_LAYERS)
                # anchor_heights_all_layers: [1d-tf.constant tf.float32,
                #                           1d-tf.constant tf.float32,
                #                           ...]
                # anchor_widths_all_layers: [1d-tf.constant tf.float32,
                #                           1d-tf.constant tf.float32,
                #                           ...]
                # num_anchors_per_location_all_layers:
                #   [Python int, Python int, ...]

                anchors_ymin,\
                    anchors_xmin,\
                    anchors_ymax,\
                    anchors_xmax, _ =\
                    anchor_processor.get_all_anchors_all_layers(
                        tf.squeeze(output_shape, axis=0),
                        anchor_heights_all_layers,
                        anchor_widths_all_layers,
                        num_anchors_per_location_all_layers,
                        config.ANCHOR_OFFSETS,
                        config.VERTICAL_OFFSETS,
                        config.ALL_LAYER_SHAPES,
                        config.ALL_LAYER_STRIDES,
                        [0.] * config.NUM_FEATURE_LAYERS,
                        [False] * config.NUM_FEATURE_LAYERS)
                # anchors_ymin: 1d-tf.Tensor(num_anchors_all_layers) tf.float32

                backbone =\
                    textboxes_plusplus_net.VGG16Backbone(FLAGS.data_format)
                feature_layers = backbone.forward(features, training=False)
                # shape = (num_feature_layers,
                #          BS,
                #          FH,
                #          FW,
                #          feature_depth)

                location_predictions, class_predictions =\
                    textboxes_plusplus_net.multibox_head(
                        feature_layers,
                        FLAGS.num_classes,
                        config.NUM_OFFSETS,
                        num_anchors_per_location_all_layers,
                        data_format=FLAGS.data_format)
                # shape = (num_feature_layers,
                #          bs,
                #          fh,
                #          fw,
                #          num_anchors_per_loc * 2 * num_offsets)

                if FLAGS.data_format == 'channels_first':
                    class_predictions =\
                        [tf.transpose(pred,
                                      [0, 2, 3, 1])
                         for pred in class_predictions]
                    location_predictions =\
                        [tf.transpose(pred,
                                      [0, 2, 3, 1])
                         for pred in location_predictions]
                class_predictions =\
                    [tf.reshape(pred,
                                [len(images), -1, FLAGS.num_classes])
                     for pred in class_predictions]
                location_predictions =\
                    [tf.reshape(pred, [len(images), -1, config.NUM_OFFSETS])
                     for pred in location_predictions]
                # shape = (num_feature_layers,
                #          bs,
                #          fh * fw * num_anchors_per_loc * 2,
                #          num_offsets)

                class_predictions = tf.concat(class_predictions, axis=1)
                location_predictions = tf.concat(location_predictions, axis=1)

                # total_parameters = 0
                # for variable in tf.trainable_variables():
                #     # shape is an array of tf.Dimension
                #     shape = variable.get_shape()
                #     print(shape)
                #     print(len(shape))
                #     variable_parameters = 1
                #     for dim in shape:
                #         print(dim)
                #         variable_parameters *= dim.value
                #     print(variable_parameters)
                #     total_parameters += variable_parameters
                # print(total_parameters)

        with tf.device('/cpu:0'):
            bboxes_pred, quadrilaterals_pred =\
                anchor_processor.batch_decode_anchors(
                    location_predictions,
                    anchors_ymin,
                    anchors_xmin,
                    anchors_ymax,
                    anchors_xmax)

            # Map each coordinate from the network-input frame back to the
            # full image: scale by tile_shape / network_input_shape, cast to
            # int, then translate by the tile offset. Even indices are y
            # coordinates (offset oys), odd indices are x (offset oxs).
            def to_full_image_coordinates(coordinates, axis):
                scale = tf.expand_dims(
                    tf.cast(tf.truediv(shapes[:, axis],
                                       output_shape[0, axis]),
                            tf.float32), -1)
                return tf.cast(coordinates * scale, tf.int32) +\
                    (oys if axis == 0 else oxs)

            bboxes_pred =\
                tf.reshape(
                    tf.stack(
                        [to_full_image_coordinates(bboxes_pred[:, :, i],
                                                   i % 2)
                         for i in range(4)], -1),
                    shape=[-1, 4])
            quadrilaterals_pred =\
                tf.reshape(
                    tf.stack(
                        [to_full_image_coordinates(
                            quadrilaterals_pred[:, :, i], i % 2)
                         for i in range(8)], -1),
                    shape=[-1, 8])
            class_predictions = tf.reshape(class_predictions,
                                           shape=[-1, FLAGS.num_classes])
            bboxes_pred = tf.cast(bboxes_pred, tf.float32)
            quadrilaterals_pred = tf.cast(quadrilaterals_pred, tf.float32)

            selected_bboxes,\
                selected_quadrilaterals,\
                selected_scores =\
                bbox_util.parse_by_class(
                    tf.shape(input_image)[:2],
                    class_predictions,
                    bboxes_pred,
                    quadrilaterals_pred,
                    FLAGS.num_classes,
                    FLAGS.select_threshold,
                    FLAGS.min_size,
                    FLAGS.keep_topk,
                    FLAGS.nms_topk,
                    FLAGS.nms_threshold)

            labels_list = []
            scores_list = []
            bboxes_list = []
            quadrilaterals_list = []
            for k, v in selected_scores.items():
                labels_list.append(tf.ones_like(v, tf.int32) * k)
                scores_list.append(v)
                bboxes_list.append(selected_bboxes[k])
                quadrilaterals_list.append(selected_quadrilaterals[k])
            all_labels = tf.concat(labels_list, axis=0)
            all_scores = tf.concat(scores_list, axis=0)
            all_bboxes = tf.concat(bboxes_list, axis=0)
            all_quadrilaterals = tf.concat(quadrilaterals_list, axis=0)

        saver = tf.train.Saver()
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)

            saver.restore(sess, get_checkpoint())

            image_paths =\
                sorted(
                    [path
                     for pattern in FLAGS.input_image_stem_patterns.split(',')
                     for path in Path(FLAGS.input_image_root).glob(pattern)],
                    key=lambda e: int(re.findall(r'(?<=_)\d+(?=\.)',
                                                 e.name)[0]))
            for i, image_path in enumerate(image_paths):
                # image = imread(str(image_path))
                image =\
                    cv2.imread(
                        str(image_path),
                        cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR
                    )[:, :, ::-1]
                start_time = time.time()

                labels_,\
                    scores_,\
                    bboxes_,\
                    quadrilaterals_ =\
                    sess.run([all_labels,
                              all_scores,
                              all_bboxes,
                              all_quadrilaterals,
                              ],
                             feed_dict={input_image: image})

                elapsed_time = time.time() - start_time
                print('{}: elapsed_time = {}'.format(i + 1, elapsed_time))
                annotation_file_name =\
                    'task1_' + image_path.name.replace('.jpg', '.txt')
                with open(
                        Path(FLAGS.output_directory).joinpath(
                            annotation_file_name), 'w') as f:
                    num_predicted_text_lines = np.shape(quadrilaterals_)[0]
                    for j in range(num_predicted_text_lines):
                        y1, x1, y2, x2,\
                            y3, x3, y4, x4 =\
                            [int(e) for e in quadrilaterals_[j, :]]
                        score = float(scores_[j])
                        if (y1 == 0 and x1 == 0 and y2 == 0 and x2 == 0
                                and y3 == 0 and x3 == 0 and y4 == 0 and x4 == 0
                                and score == 0.0):
                            continue
                        f.write('{},{},{},{},{},{},{},{},{}\n'.format(
                            x1, y1, x2, y2, x3, y3, x4, y4, score))
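# Hedged sketch (not part of the original source) of the tile-to-full-image
# coordinate mapping that main() applies above to every bbox/quadrilateral
# value: scale from network-input resolution back to tile resolution, then
# translate by the tile's origin. All numbers are hypothetical.
def tile_to_full_image(coordinate, tile_size, network_input_size,
                       tile_offset):
    return int(coordinate * tile_size / network_input_size) + tile_offset

# e.g., y = 192 predicted on a 540-pixel-tall tile that was resized to 384,
# with the tile starting at row 270 of the full image:
# tile_to_full_image(192, 540, 384, 270) == 540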
def model_fn(features, labels, mode, params):
    # shape = labels['shape']
    loc_targets = labels['loc_targets']  # (bs, n_anchors_all_layers, n_offsets)
    cls_targets = labels['cls_targets']  # (bs, n_anchors_all_layers)
    # match_scores = labels['match_scores']  # (bs, n_anchors_all_layers)

    global global_anchor_info
    # decode_fn = global_anchor_info['decode_fn']
    # num_anchors_per_layer = global_anchor_info['num_anchors_per_layer']
    num_anchors_per_location_all_layers =\
        global_anchor_info['num_anchors_per_location_all_layers']

    with tf.variable_scope(params['model_scope'],
                           default_name=None,
                           values=[features],
                           reuse=tf.AUTO_REUSE):

        # features: 4d-tf.Tensor-(bs, n_channels, fH, fW)-tf.float32
        backbone = textboxes_plusplus_net.VGG16Backbone(params['data_format'])
        feature_layers =\
            backbone.forward(features,
                             training=(mode == tf.estimator.ModeKeys.TRAIN))

        # shape = (num_feature_layers,
        #          bs,
        #          num_anchors_per_loc * 2 * num_offsets,
        #          fh,
        #          fw)
        location_predictions, class_predictions =\
            textboxes_plusplus_net.multibox_head(
                feature_layers,
                params['num_classes'],
                config.NUM_OFFSETS,
                num_anchors_per_location_all_layers,
                data_format=params['data_format'])

        # shape = (num_feature_layers,
        #          bs,
        #          fh,
        #          fw,
        #          num_anchors_per_loc * 2 * num_offsets)
        if params['data_format'] == 'channels_first':
            location_predictions =\
                [tf.transpose(pred,
                              [0, 2, 3, 1]) for pred in location_predictions]
            class_predictions =\
                [tf.transpose(pred,
                              [0, 2, 3, 1]) for pred in class_predictions]
            # if channels_first ==> move channel to last

        # shape = (num_feature_layers,
        #          bs,
        #          num_anchors_per_layer=fh * fw * num_anchors_per_loc * 2,
        #          num_offsets)
        location_predictions = [tf.reshape(pred,
                                           [tf.shape(features)[0],
                                            -1,
                                            config.NUM_OFFSETS])
                                for pred in location_predictions]
        class_predictions = [tf.reshape(pred,
                                        [tf.shape(features)[0],
                                         -1,
                                         params['num_classes']])
                             for pred in class_predictions]

        # shape = (bs,
        #          num_anchors_all_layers,
        #          num_offsets)
        location_predictions = tf.concat(location_predictions, axis=1)
        class_predictions = tf.concat(class_predictions, axis=1)

        # shape = (num_anchors_per_batch,
        #          num_offsets)
        location_predictions = tf.reshape(location_predictions,
                                          [-1, config.NUM_OFFSETS])
        class_predictions = tf.reshape(class_predictions,
                                       [-1, params['num_classes']])

    with tf.device('/cpu:0'):
        with tf.control_dependencies([class_predictions,
                                      location_predictions]):
            with tf.name_scope('post_forward'):
                # decoded_location_predictions =\
                #     decode_fn(tf.reshape(location_predictions,
                #                          [tf.shape(features)[0],
                #                           -1,
                #                           config.NUM_OFFSETS]))
                # decoded_location_predictions =\
                #     tf.reshape(decoded_location_predictions,
                #                [-1, config.NUM_OFFSETS])

                # - location_predictions[i, :] contains:
                # + before decode_fn:
                # [pred_cy, pred_cx, pred_h, pred_w, pred_y1, pred_x1, ...]
                # + after decode_fn:
                # [pred_ymin*, pred_xmin*, pred_ymax*, pred_xmax*, pred_y1*,
                # ...]
                # in which * means decoded value

                # e.g., cls_targets.get_shape():  (bs, n_anchors)
                # e.g., loc_targets.get_shape():  (bs, n_anchors, n_offsets)
                flatten_cls_targets = tf.reshape(cls_targets, [-1])
                # flatten_match_scores = tf.reshape(match_scores, [-1])
                flatten_loc_targets = tf.reshape(loc_targets,
                                                 [-1, config.NUM_OFFSETS])
                # - loc_targets:
                # + gt_target 0 for negatives and ignores
                # + gt_target otherwise for object (positives and labeled bg)
                # - cls_targets:
                # + gt_label -1 for ignores
                # + gt_label 0 for labeled background (usually empty) and
                # negatives considered as background
                # + gt_label > 0 for detection object (positives)
                # - match_scores:
                # + gt_score >= 0
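                # e.g., a cls_targets row of [-1, 0, 3, 0] marks anchor 0 as
                # ignored, anchors 1 and 3 as background, and anchor 2 as
                # matched to an object of class 3.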

                # Each positive example has one label.
                # shape = (num_anchors_per_batch, )
                positive_mask = flatten_cls_targets > 0
                # shape = ()
                # n_positives = tf.count_nonzero(positive_mask)

                # shape = (bs, )
                # batch_n_positives = [n_positives_0, ..., n_positives_bs-1]
                batch_n_positives = tf.count_nonzero(cls_targets > 0, -1)

                # shape = (bs, num_anchors_all_layers)
                batch_negative_mask = tf.equal(cls_targets, 0)
                # shape = (bs, )
                batch_n_negatives = tf.count_nonzero(batch_negative_mask, -1)

                # shape = (bs, )
                batch_n_neg_select =\
                    tf.to_int32(params['negative_ratio'] *
                                tf.to_float(batch_n_positives))
                # shape = (bs, )
                batch_n_neg_select =\
                    tf.minimum(batch_n_neg_select,
                               tf.to_int32(batch_n_negatives))

                # hard negative mining for classification
                # class_predictions.get_shape(): (num_anchors_per_batch,
                #                                 num_classes)
                # shape = (bs, num_anchors_all_layers)
                predictions_for_bg =\
                    tf.nn.softmax(tf.reshape(class_predictions,
                                             [tf.shape(features)[0],
                                              -1,
                                              params['num_classes']]))[:, :, 0]
                # shape = (bs, num_anchors_all_layers)
                prob_for_negatives =\
                    tf.where(batch_negative_mask,
                             0. - predictions_for_bg,
                             # ignore all the positives
                             0. - tf.ones_like(predictions_for_bg))
                # shape = (bs, num_anchors_all_layers)
                # rearrange the anchors according to the prob for bg.
                topk_prob_for_bg, _ =\
                    tf.nn.top_k(prob_for_negatives,
                                k=tf.shape(prob_for_negatives)[1])
                # shape = (bs, )
                score_at_k =\
                    tf.gather_nd(topk_prob_for_bg,
                                 tf.stack([tf.range(tf.shape(features)[0]),
                                           batch_n_neg_select - 1],
                                          axis=-1))
                # tf.stack =
                # [
                #   [0, n_negatives_0 - 1],
                #   [1, n_negatives_1 - 1],
                #   ...
                #   [bs - 1, n_negatives_bs-1 - 1],
                # ]
                # topk_prob_for_bg =
                #            n_negatives_0 - 1
                #                    | n_negatives_1 - 1
                #                    |       |  n_negatives_bs-1 - 1
                #                    |       |         |
                #                   \/      \/        \/
                # [        0        1       2         x     y   n_anchors-1
                #   0   [-0.001, -0.002, -0.01, ..., -1,   -1,   -1]
                #   1   [-0.002, -0.008, -0.05, ..., -0.7, -1,   -1]
                #       ...
                #  bs-1 [-0.05,  -0.09,  -0.1, ...,  -0.9, -1,   -1]
                # ]
                # NOTE: n_negatives_i never points to -1 because
                # batch_n_neg_select = tf.minimum(batch_n_neg_select,
                # batch_n_negatives)
                # score_at_k =
                # [    0       1         bs-1
                #   -0.002, -0.05, ..., -0.9
                # ]

                # shape = (bs, num_anchors_all_layers)
                selected_neg_mask =\
                    prob_for_negatives >= tf.expand_dims(score_at_k,
                                                         axis=-1)
                # selected_neg_mask =
                # [
                #  original_order[True, True, False, ..., False, False, False]
                #  original_order[True, True, True, False, ...,  False, False]
                #                 ...
                #  original_order[True, True, True, ...,  True,  False, False]
                # ]

                # include both selected negatives and all positive examples
                # The mask is recomputed for every batch and must be treated
                # as a constant: gradients flow through the masked
                # predictions, never through the mask itself.
                final_mask =\
                    tf.stop_gradient(
                        tf.logical_or(
                            tf.reshape(
                                tf.logical_and(
                                    batch_negative_mask,
                                    selected_neg_mask),
                                [-1]),
                            positive_mask))

                # shape = (n_positive_anchors_per_batch +
                #          n_chosen_negative_anchors_per_batch, num_classes)
                class_predictions = tf.boolean_mask(class_predictions,
                                                    final_mask)
                # rows are kept only for positive anchors and selected
                # negative anchors; all other rows are filtered out

                # shape = (n_positive_anchors_per_batch, num_offsets)
                location_predictions =\
                    tf.boolean_mask(location_predictions,
                                    tf.stop_gradient(positive_mask))
                # shape = (n_positive_anchors_per_batch +
                #          n_chosen_negative_anchors_per_batch, )
                flatten_cls_targets =\
                    tf.boolean_mask(  # filter out unused negatives
                        tf.clip_by_value(  # consider ignores as background
                            flatten_cls_targets,
                            0,
                            params['num_classes']),
                        final_mask)
                # shape = (n_positive_anchors_per_batch, num_offsets)
                flatten_loc_targets =\
                    tf.stop_gradient(
                        tf.boolean_mask(
                            flatten_loc_targets,
                            positive_mask))

                # location_predictions comes from the model,
                # flatten_loc_targets comes from the data
                predictions = {
                    'classes': tf.argmax(class_predictions, axis=-1),
                    'probabilities': tf.reduce_max(
                        tf.nn.softmax(
                            class_predictions,
                            name='softmax_tensor'),
                        axis=-1),
                    # 'loc_predict': decoded_location_predictions
                }

                cls_accuracy =\
                    tf.metrics.accuracy(flatten_cls_targets,
                                        predictions['classes'])
                cls_precision =\
                    tf.metrics.precision(flatten_cls_targets,
                                         predictions['classes'])
                cls_recall =\
                    tf.metrics.recall(flatten_cls_targets,
                                      predictions['classes'])
                metrics = {'cls_accuracy': cls_accuracy,
                           'cls_precision': cls_precision,
                           'cls_recall': cls_recall}

                # for logging purposes
                tf.identity(cls_accuracy[1], name='cls_accuracy')
                tf.summary.scalar('cls_accuracy', cls_accuracy[1])
                tf.identity(cls_precision[1], name='cls_precision')
                tf.summary.scalar('cls_precision', cls_precision[1])
                tf.identity(cls_recall[1], name='cls_recall')
                tf.summary.scalar('cls_recall', cls_recall[1])

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # - function name differences:
    # + without 'sparse_': labels is a one-hot tensor of shape
    # (n_examples, n_classes)
    # + with 'sparse_': labels is a tensor of shape (n_examples, )
    # + the tf.losses version returns mean([loss(example) for example in
    # examples])
    # + the tf.nn '_with_logits' version returns [loss(example) for example
    # in examples]
    # NOTE: the cross-entropy functions apply softmax to the logits
    # internally.
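    # e.g., with labels [1, 0] and logits of shape (2, num_classes),
    # tf.losses.sparse_softmax_cross_entropy returns the scalar mean of the
    # two per-example cross-entropies, while
    # tf.nn.sparse_softmax_cross_entropy_with_logits returns a length-2
    # vector of them.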
    # shape = ()
    # average class loss of all examples
    cross_entropy =\
        tf.losses.sparse_softmax_cross_entropy(
            labels=flatten_cls_targets,
            logits=class_predictions) *\
        (params['negative_ratio'] + 1.)
    # create a tensor named cross_entropy for logging purposes
    tf.identity(cross_entropy,
                name='cross_entropy_loss')
    tf.summary.scalar('cross_entropy_loss',
                      cross_entropy)

    loc_loss =\
        modified_smooth_l1(location_predictions,
                           flatten_loc_targets,
                           sigma=1.)

    # average location loss of all positive anchors
    loc_loss = tf.reduce_mean(tf.reduce_sum(loc_loss,
                                            axis=-1),
                              name='location_loss')
    tf.summary.scalar('location_loss', loc_loss)
    tf.losses.add_loss(loc_loss)

    l2_loss_vars = []
    for trainable_var in tf.trainable_variables():
        if '_bn' not in trainable_var.name:
            if 'conv4_3_scale' not in trainable_var.name:
                l2_loss_vars.append(tf.nn.l2_loss(trainable_var))
            else:
                l2_loss_vars.append(tf.nn.l2_loss(trainable_var) * 0.1)
    # add weight decay to the loss
    # We exclude the batch norm variables because doing so leads to a small
    # improvement in accuracy.
    total_loss =\
        tf.add(cross_entropy + loc_loss,
               tf.multiply(params['weight_decay'],
                           tf.add_n(l2_loss_vars),
                           name='l2_loss'),
               name='total_loss')

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        lr_values = [params['learning_rate'] * decay
                     for decay in params['lr_decay_factors']]
        learning_rate =\
            tf.train.piecewise_constant(
                tf.cast(global_step, tf.int32),
                [int(_) for _ in params['decay_boundaries']],
                lr_values)
        truncated_learning_rate =\
            tf.maximum(learning_rate,
                       tf.constant(params['end_learning_rate'],
                                   dtype=learning_rate.dtype),
                       name='learning_rate')
        # create a tensor named learning_rate for logging purposes
        tf.summary.scalar('learning_rate', truncated_learning_rate)

        optimizer = tf.train.MomentumOptimizer(
            learning_rate=truncated_learning_rate,
            momentum=params['momentum'])
        optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(total_loss, global_step)
    else:
        train_op = None

    # used to train from scratch or to finetune from this model's own
    # checkpoint (i.e., the network built here has exactly the same
    # architecture as the one stored in the checkpoint file)
    # return tf.estimator.EstimatorSpec(
    #                         mode=mode,
    #                         predictions=predictions,
    #                         loss=total_loss,
    #                         train_op=train_op,
    #                         eval_metric_ops=metrics,
    #                         scaffold=tf.train.Scaffold(None))
    # used to finetune from other models
    return tf.estimator.EstimatorSpec(
                            mode=mode,
                            predictions=predictions,
                            loss=total_loss,
                            train_op=train_op,
                            eval_metric_ops=metrics,
                            scaffold=tf.train.Scaffold(init_fn=get_init_fn()))
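# A hedged NumPy sketch (not in the original source) of the hard negative
# mining pattern used in model_fn above: sort the negated background
# probabilities in descending order (the tf.nn.top_k step), read the score
# at the k-th position per image (the tf.gather_nd step), and keep the
# negatives scoring at least that threshold. All values are hypothetical.
import numpy as np

def select_hard_negatives_demo(prob_for_bg, negative_mask, n_neg_select):
    # prob_for_bg: (bs, n_anchors) background probabilities.
    # negative_mask: (bs, n_anchors) bool, True for negative anchors.
    # n_neg_select: (bs,) number of negatives to keep per image (>= 1).
    # Mirror the 0. - prob trick: hard negatives (low bg prob) score
    # highest; positives are pushed to -1 so they are never selected.
    scores = np.where(negative_mask, -prob_for_bg,
                      -np.ones_like(prob_for_bg))
    sorted_scores = -np.sort(-scores, axis=1)  # descending, like top_k
    score_at_k = sorted_scores[np.arange(len(scores)), n_neg_select - 1]
    return scores >= score_at_k[:, None]

# e.g., keeping the 1 and 2 hardest negatives respectively:
# select_hard_negatives_demo(np.array([[0.9, 0.1, 0.8], [0.2, 0.7, 0.6]]),
#                            np.array([[True, True, False],
#                                      [True, True, True]]),
#                            np.array([1, 2]))
# -> [[False, True, False], [True, False, True]]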
def model_fn(features, labels, mode, params):
    file_name = features['file_name']
    file_name = tf.identity(file_name, name='file_name')
    shape = features['shape']
    output_shape = features['output_shape']
    image = features['image']

    anchor_processor = anchor_manipulator.AnchorProcessor(
        positive_threshold=None,
        ignore_threshold=None,
        prior_scaling=config.PRIOR_SCALING)
    with tf.variable_scope(params['model_scope'],
                           default_name=None,
                           values=[image],
                           reuse=tf.AUTO_REUSE):
        with tf.device('/cpu:0'):
            anchor_heights_all_layers,\
            anchor_widths_all_layers,\
            num_anchors_per_location_all_layers=\
                anchor_processor.get_anchors_size_all_layers(
                    config.ALL_ANCHOR_SCALES,
                    config.ALL_EXTRA_SCALES,
                    config.ALL_ANCHOR_RATIOS,
                    config.NUM_FEATURE_LAYERS)

            anchors_ymin,\
            anchors_xmin,\
            anchors_ymax,\
            anchors_xmax,\
            _=\
                anchor_processor.get_all_anchors_all_layers(
                    tf.squeeze(output_shape, axis=0),
                    anchor_heights_all_layers,
                    anchor_widths_all_layers,
                    num_anchors_per_location_all_layers,
                    config.ANCHOR_OFFSETS,
                    config.VERTICAL_OFFSETS,
                    config.ALL_LAYER_SHAPES,
                    config.ALL_LAYER_STRIDES,
                    [0.] * config.NUM_FEATURE_LAYERS,
                    [False] * config.NUM_FEATURE_LAYERS)

            backbone=\
                textboxes_plusplus_net.VGG16Backbone(params['data_format'])
            feature_layers = backbone.forward(
                image, training=(mode == tf.estimator.ModeKeys.TRAIN))
            location_predictions, class_predictions=\
                textboxes_plusplus_net.multibox_head(
                    feature_layers,
                    params['num_classes'],
                    config.NUM_OFFSETS,
                    num_anchors_per_location_all_layers,
                    data_format=params['data_format'])
            if params['data_format'] == 'channels_first':
                location_predictions=\
                    [tf.transpose(pred, [0, 2, 3, 1])\
                     for pred in location_predictions]
                class_predictions=\
                    [tf.transpose(pred, [0, 2, 3, 1])\
                     for pred in class_predictions]

            location_predictions=\
                [tf.reshape(pred,
                            [tf.shape(image)[0],
                             -1,
                             config.NUM_OFFSETS])\
                 for pred in location_predictions]
            class_predictions=\
                [tf.reshape(pred,
                            [tf.shape(image)[0],
                             -1,
                             params['num_classes']])\
                 for pred in class_predictions]

            location_predictions = tf.concat(location_predictions, axis=1)
            class_predictions = tf.concat(class_predictions, axis=1)

            location_predictions = tf.reshape(location_predictions,
                                              [-1, config.NUM_OFFSETS])
            class_predictions = tf.reshape(class_predictions,
                                           [-1, params['num_classes']])
    with tf.device('/cpu:0'):
        bboxes_pred,\
        quadrilaterals_pred=\
            anchor_processor.decode_anchors(
                location_predictions,
                anchors_ymin,
                anchors_xmin,
                anchors_ymax,
                anchors_xmax)
        selected_bboxes,\
        selected_quadrilaterals,\
        selected_scores=\
            bbox_util.parse_by_class(
                tf.squeeze(output_shape, axis=0),
                class_predictions,
                bboxes_pred,
                quadrilaterals_pred,
                params['num_classes'],
                params['select_threshold'],
                params['min_size'],
                params['keep_topk'],
                params['nms_topk'],
                params['nms_threshold'])

    labels_list = []
    scores_list = []
    bboxes_list = []
    quadrilaterals_list = []
    for k, v in selected_scores.items():
        labels_list.append(tf.ones_like(v, tf.int32) * k)
        scores_list.append(v)
        bboxes_list.append(selected_bboxes[k])
        quadrilaterals_list.append(selected_quadrilaterals[k])
    all_labels = tf.concat(labels_list, axis=0)
    all_scores = tf.concat(scores_list, axis=0)
    all_bboxes = tf.concat(bboxes_list, axis=0)
    all_quadrilaterals = tf.concat(quadrilaterals_list, axis=0)

    save_image_op=\
        tf.py_func(save_image_with_labels,
                   [textboxes_plusplus_preprocessing.unwhiten_image(
                       tf.squeeze(image, axis=0),
                       output_rgb=False),
                    all_labels * tf.to_int32(all_scores > 0.3),
                    all_scores,
                    all_bboxes,
                    all_quadrilaterals],
                   tf.int64,
                   stateful=True)
    tf.identity(save_image_op, name='save_image_op')
    predictions=\
        {'file_name': file_name,
         'shape': shape,
         'output_shape': output_shape}
    for class_ind in range(1, params['num_classes']):
        predictions['scores_{}'.format(class_ind)]=\
            tf.expand_dims(selected_scores[class_ind], axis=0)
        predictions['bboxes_{}'.format(class_ind)]=\
            tf.expand_dims(selected_bboxes[class_ind], axis=0)
        predictions['quadrilaterals_{}'.format(class_ind)]=\
            tf.expand_dims(selected_quadrilaterals[class_ind], axis=0)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions,
                                          prediction_hooks=None,
                                          loss=None,
                                          train_op=None)
    else:
        raise ValueError('This script only supports "PREDICT" mode!')
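# Hedged usage sketch (not from the original source): wiring the
# PREDICT-only model_fn above into a tf.estimator.Estimator. The params keys
# match the ones the function reads; the values and model_dir are
# hypothetical placeholders.
def build_predict_estimator_demo():
    return tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir='/tmp/textboxes_plusplus',  # hypothetical path
        params={'model_scope': 'textboxes_plusplus',
                'data_format': 'channels_last',
                'num_classes': 2,
                'select_threshold': 0.5,
                'min_size': 4.,
                'keep_topk': 400,
                'nms_topk': 200,
                'nms_threshold': 0.45})
# predictions would then come from
# build_predict_estimator_demo().predict(input_fn=some_predict_input_fn).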