def test_create_tf_example(self):
        """Verify create_tf_example output for one image with one box.

        A random 256x256 RGB image is saved as 'tmp_image.jpg' in the test's
        temp directory and paired with a single annotation of category 2
        ('cat'). The expected box values (0.25 / 0.75) are consistent with
        'bbox' being [x, y, width, height] in pixels, normalized by the
        image width and height.
        """
        image_file_name = 'tmp_image.jpg'
        tmp_dir = self.get_temp_dir()
        save_path = os.path.join(tmp_dir, image_file_name)

        # np.random.rand returns float64 values in [0, 1). PIL's 'RGB' mode
        # is 8 bits per channel, so convert to uint8 first and let
        # fromarray infer the mode, rather than forcing 'RGB' onto a float64
        # buffer (which reinterprets raw bytes and is deprecated in recent
        # Pillow releases).
        image_data = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
        PIL.Image.fromarray(image_data).save(save_path)

        image = {
            'file_name': image_file_name,
            'height': 256,
            'width': 256,
            'id': 11,
        }

        annotations_list = [{
            'area': .5,
            'iscrowd': False,
            'image_id': 11,
            'bbox': [64, 64, 128, 128],
            'category_id': 2,
            'id': 1000,
        }]

        category_index = {
            1: {
                'name': 'dog',
                'id': 1
            },
            2: {
                'name': 'cat',
                'id': 2
            },
            3: {
                'name': 'human',
                'id': 3
            }
        }

        # The first element of the returned triple is not checked here.
        (_, example,
         num_annotations_skipped) = create_coco_tfrecord.create_tf_example(
             image, annotations_list, tmp_dir, category_index)

        self.assertEqual(num_annotations_skipped, 0)

        features = example.features.feature
        self._assertProtoEqual(
            features['image/height'].int64_list.value, [256])
        self._assertProtoEqual(
            features['image/width'].int64_list.value, [256])
        self._assertProtoEqual(
            features['image/filename'].bytes_list.value,
            [six.b(image_file_name)])
        self._assertProtoEqual(
            features['image/source_id'].bytes_list.value,
            [six.b(str(image['id']))])
        self._assertProtoEqual(
            features['image/format'].bytes_list.value, [six.b('jpeg')])
        self._assertProtoEqual(
            features['image/object/bbox/xmin'].float_list.value, [0.25])
        self._assertProtoEqual(
            features['image/object/bbox/ymin'].float_list.value, [0.25])
        self._assertProtoEqual(
            features['image/object/bbox/xmax'].float_list.value, [0.75])
        self._assertProtoEqual(
            features['image/object/bbox/ymax'].float_list.value, [0.75])
        self._assertProtoEqual(
            features['image/object/class/text'].bytes_list.value,
            [six.b('cat')])
    def test_create_tf_example_with_instance_masks(self):
        """Verify create_tf_example output when include_masks=True.

        A random 8x8 RGB image is saved as 'tmp_image.jpg' in the test's
        temp directory and paired with a single 'dog' annotation whose box
        spans the whole image and whose segmentation holds two polygons.
        Besides the image and box features, the test decodes the entries
        of 'image/object/mask' with PIL and expects exactly one 8x8 mask
        with ones in the top-left and bottom-right corners.
        """
        image_file_name = 'tmp_image.jpg'
        tmp_dir = self.get_temp_dir()
        save_path = os.path.join(tmp_dir, image_file_name)

        # np.random.rand returns float64 values in [0, 1). PIL's 'RGB' mode
        # is 8 bits per channel, so convert to uint8 first and let
        # fromarray infer the mode, rather than forcing 'RGB' onto a float64
        # buffer (which reinterprets raw bytes and is deprecated in recent
        # Pillow releases).
        image_data = (np.random.rand(8, 8, 3) * 255).astype(np.uint8)
        PIL.Image.fromarray(image_data).save(save_path)

        image = {
            'file_name': image_file_name,
            'height': 8,
            'width': 8,
            'id': 11,
        }

        annotations_list = [{
            'area': .5,
            'iscrowd': False,
            'image_id': 11,
            'bbox': [0, 0, 8, 8],
            'segmentation': [[4, 0, 0, 0, 0, 4], [8, 4, 4, 8, 8, 8]],
            'category_id': 1,
            'id': 1000,
        }]

        category_index = {
            1: {
                'name': 'dog',
                'id': 1
            },
        }

        # The first element of the returned triple is not checked here.
        (_, example,
         num_annotations_skipped) = create_coco_tfrecord.create_tf_example(
             image,
             annotations_list,
             tmp_dir,
             category_index,
             include_masks=True)

        self.assertEqual(num_annotations_skipped, 0)

        features = example.features.feature
        self._assertProtoEqual(
            features['image/height'].int64_list.value, [8])
        self._assertProtoEqual(
            features['image/width'].int64_list.value, [8])
        self._assertProtoEqual(
            features['image/filename'].bytes_list.value,
            [six.b(image_file_name)])
        self._assertProtoEqual(
            features['image/source_id'].bytes_list.value,
            [six.b(str(image['id']))])
        self._assertProtoEqual(
            features['image/format'].bytes_list.value, [six.b('jpeg')])
        self._assertProtoEqual(
            features['image/object/bbox/xmin'].float_list.value, [0])
        self._assertProtoEqual(
            features['image/object/bbox/ymin'].float_list.value, [0])
        self._assertProtoEqual(
            features['image/object/bbox/xmax'].float_list.value, [1])
        self._assertProtoEqual(
            features['image/object/bbox/ymax'].float_list.value, [1])
        self._assertProtoEqual(
            features['image/object/class/text'].bytes_list.value,
            [six.b('dog')])

        # Each mask entry is an encoded image; decode it back to an array.
        pil_masks = [
            np.array(PIL.Image.open(io.BytesIO(encoded_mask)))
            for encoded_mask in features['image/object/mask'].bytes_list.value
        ]
        self.assertEqual(len(pil_masks), 1)

        expected_mask = [
            [1, 1, 1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1],
            [0, 0, 0, 0, 0, 0, 1, 1],
            [0, 0, 0, 0, 0, 1, 1, 1],
            [0, 0, 0, 0, 1, 1, 1, 1],
        ]
        self.assertAllEqual(pil_masks[0], expected_mask)