def rotate():
    """Return `tf_image.random_rotate90` applied to the current sample.

    This function takes no arguments: `image`, `bboxes`, `xs` and `ys` are
    free variables that are looked up in the enclosing/module scope when the
    function is called.
    NOTE(review): no binding for those names is visible next to this
    definition -- confirm where they are expected to come from.

    Returns:
        Whatever `tf_image.random_rotate90(image, bboxes, xs, ys)` returns.
    """
    return tf_image.random_rotate90(image, bboxes, xs, ys)
def preprocess_for_train(image,
                         labels,
                         bboxes,
                         xs,
                         ys,
                         out_shape,
                         data_format='NHWC',
                         scope='ssd_preprocessing_train'):
    """Apply the training-time augmentation pipeline to a single sample.

    The steps, in order (each optional step is gated by a module-level flag):

    1. Random horizontal flip (`FLIP`).
    2. Random rotation via `tf_rotate_image` (`ROTATE`).
    3. Crop a window and resize it to 640x640. With `USE_NM_CROP` the window
       comes from `generate_sample` (run through `tf.py_func`) using the
       boxes whose label is >= 1; otherwise a square window of side
       `640 / s` is placed at a random offset, where `s` is drawn from
       {0.5, 1.0}. Parts of the window outside the image are filled with
       the value 128.
    4. Re-express the annotations relative to the crop and mark boxes that
       overlap the crop by less than `BBOX_CROP_OVERLAP` with `LABEL_IGNORE`.
    5. With probability 0.2, `tf_image.random_rotate90` (`ROTATE_90`).
    6. With probability `config.expand_prob`, enlarge the canvas by a random
       factor in [1, `MAX_EXPAND_SCALE`) (only if `MAX_EXPAND_SCALE > 1`).
    7. Mark boxes whose shorter side is outside
       [`MIN_SHORTER_SIDE`, `MAX_SHORTER_SIDE`] with `LABEL_IGNORE`
       (`USING_SHORTER_SIDE_FILTERING`).
    8. Random colour distortion, mean subtraction with
       `[_R_MEAN, _G_MEAN, _B_MEAN]`, and an optional transpose to
       channels-first.

    Args:
        image: A rank-3 `Tensor` `[height, width, channels]`.
        labels: Per-box labels; values >= 1 are treated as valid boxes when
            choosing the crop window.
        bboxes: Per-box bounding boxes, forwarded to the `tf_image`/`tfe`
            helpers (presumably normalised `[ymin, xmin, ymax, xmax]` --
            TODO confirm against those helpers).
        xs: Per-box x coordinates, treated as normalised to the image width
            (they are multiplied by `out_shape[1]` before the side filter).
        ys: Per-box y coordinates, treated as normalised to the image height
            (they are multiplied by `out_shape[0]` before the side filter).
        out_shape: `(height, width)` used to convert `xs`/`ys` to pixels for
            the shorter-side filter.
            NOTE(review): the crop itself is resized to a hard-coded
            640x640, not to `out_shape` -- confirm callers pass (640, 640).
        data_format: `'NHWC'` keeps the image as `[H, W, C]`; `'NCHW'`
            transposes it to `[C, H, W]`.
        scope: Name scope for the created ops.

    Returns:
        A tuple `(image, labels, bboxes, xs, ys)` holding the augmented,
        mean-subtracted image and the correspondingly updated annotations.

    Raises:
        ValueError: If `image` is not a rank-3 tensor.
    """
    fast_mode = False
    with tf.name_scope(scope, 'ssd_preprocessing_train',
                       [image, labels, bboxes]):
        if image.get_shape().ndims != 3:
            raise ValueError('Input must be of size [height, width, C>0]')

        # Randomly flip the image horizontally.
        if FLIP:
            image, bboxes, xs, ys = tf_image.random_flip_left_right_bbox(
                image, bboxes, xs, ys)
        if ROTATE:
            # Random rotation of the image in [-10, 10]; tf_rotate_image
            # returns fresh bboxes alongside the rotated xs/ys.
            image, bboxes, xs, ys = tf_rotate_image(image, xs, ys)

        image_shape = tf.cast(tf.shape(image), dtype=tf.float32)
        image_h, image_w = image_shape[0], image_shape[1]

        if USE_NM_CROP:
            # Only boxes with a positive label take part in window selection.
            mask = tf.greater_equal(labels, 1)
            valid_bboxes = tf.boolean_mask(bboxes, mask)
            # FIXME: valid_bboxes may be empty.
            # NOTE: the function wrapped by tf.py_func must return numpy
            # values matching the declared output dtype.
            crop_bbox = tf.py_func(generate_sample,
                                   [image_shape, valid_bboxes], tf.float32)
        else:
            # Pick the crop scale: the first element of the shuffled pair.
            scales = tf.random_shuffle([0.5, 1.])
            # Debug output of the shuffled scales.
            scales = tf.Print(scales, [scales])
            target_h = tf.cast(640 / scales[0], dtype=tf.float32)
            target_w = tf.cast(640 / scales[0], dtype=tf.float32)
            # The window may be larger than the image; the offset range is
            # clamped at zero and the overhang is padded by crop_and_resize.
            bbox_begin_h_max = tf.maximum(image_h - target_h, 0)
            bbox_begin_w_max = tf.maximum(image_w - target_w, 0)
            bbox_begin_h = tf.random_uniform([],
                                             minval=0,
                                             maxval=bbox_begin_h_max,
                                             dtype=tf.float32)
            bbox_begin_w = tf.random_uniform([],
                                             minval=0,
                                             maxval=bbox_begin_w_max,
                                             dtype=tf.float32)

            # Normalised [y1, x1, y2, x2] as expected by crop_and_resize.
            crop_bbox = [
                bbox_begin_h / image_h,
                bbox_begin_w / image_w,
                (bbox_begin_h + target_h) / image_h,
                (bbox_begin_w + target_w) / image_w,
            ]

        # crop_and_resize works on batches, so add and then drop a batch axis.
        image = tf.image.crop_and_resize(tf.expand_dims(image, 0), [crop_bbox],
                                         [0], (640, 640),
                                         extrapolation_value=128)
        image = tf.squeeze(image, 0)
        bboxes, xs, ys = tfe.bboxes_resize(crop_bbox, bboxes, xs, ys)
        labels, bboxes, xs, ys = tfe.bboxes_filter_overlap(
            labels,
            bboxes,
            xs,
            ys,
            threshold=BBOX_CROP_OVERLAP,
            assign_value=LABEL_IGNORE)

        if ROTATE_90:
            rnd = tf.random_uniform((), minval=0, maxval=1)
            image, bboxes, xs, ys = tf.cond(
                tf.less(rnd, 0.2),
                lambda: tf_image.random_rotate90(image, bboxes, xs, ys),
                lambda: (image, bboxes, xs, ys))

        # Expand the image canvas.
        # NOTE(review): nothing after this step resizes the image back to a
        # fixed size, so an expanded sample is larger than 640x640 -- confirm
        # this is compatible with downstream batching.
        if MAX_EXPAND_SCALE > 1:
            rnd2 = tf.random_uniform((), minval=0, maxval=1)

            def expand():
                scale = tf.random_uniform([],
                                          minval=1.0,
                                          maxval=MAX_EXPAND_SCALE,
                                          dtype=tf.float32)
                image_shape = tf.cast(tf.shape(image), dtype=tf.float32)
                image_h, image_w = image_shape[0], image_shape[1]
                target_h = tf.cast(image_h * scale, dtype=tf.int32)
                target_w = tf.cast(image_w * scale, dtype=tf.int32)
                # This is a Python-level call: it runs when tf.cond traces
                # the branch, not each time the branch is selected.
                tf.logging.info('expanded')
                return tf_image.resize_image_bboxes_with_crop_or_pad(
                    image, bboxes, xs, ys, target_h, target_w)

            def no_expand():
                return image, bboxes, xs, ys

            image, bboxes, xs, ys = tf.cond(tf.less(rnd2, config.expand_prob),
                                            expand, no_expand)

        # Filter bboxes using the length of their shorter side. The filter
        # works in pixels, so xs/ys are scaled up and then normalised again.
        if USING_SHORTER_SIDE_FILTERING:
            xs = xs * out_shape[1]
            ys = ys * out_shape[0]
            labels, bboxes, xs, ys = tfe.bboxes_filter_by_shorter_side(
                labels,
                bboxes,
                xs,
                ys,
                min_height=MIN_SHORTER_SIDE,
                max_height=MAX_SHORTER_SIDE,
                assign_value=LABEL_IGNORE)
            xs = xs / out_shape[1]
            ys = ys / out_shape[0]

        # Randomly distort the colors. There are 4 ways to do it.
        # distort_color is fed values scaled to [0, 1]; the result is scaled
        # back to [0, 255] before mean subtraction.
        image = apply_with_random_selector(
            image / 255.0,
            lambda x, ordering: distort_color(x, ordering, fast_mode),
            num_cases=4)
        image = image * 255.

        # FIXME: change the input value
        # NOTE: resnet v1 uses the VGG data processing, so we do the same.
        image = tf_image_whitened(image, [_R_MEAN, _G_MEAN, _B_MEAN])
        # Image data format.
        if data_format == 'NCHW':
            image = tf.transpose(image, perm=(2, 0, 1))
        return image, labels, bboxes, xs, ys