Example #1
    def _convert_dataset(self, dataset_split):
        """Converts the specified dataset split to TFRecord format.

        Args:
          dataset_split: The dataset split (e.g., train, test).

        Raises:
          RuntimeError: If the loaded image and label have different shapes.
        """
        dataset = os.path.basename(dataset_split)[:-4]
        sys.stdout.write('Processing ' + dataset)
        filenames = [x.strip('\n') for x in open(dataset_split, 'r')]
        num_images = len(filenames)
        num_per_shard = int(math.ceil(num_images / _NUM_SHARDS))

        image_reader = build_data.ImageReader('jpeg', channels=3)
        label_reader = build_data.ImageReader('png', channels=1)

        for shard_id in range(_NUM_SHARDS):
            output_filename = os.path.join(
                self.output_dir,
                '%s-%05d-of-%05d.tfrecord' % (dataset, shard_id, _NUM_SHARDS))
            with tf.python_io.TFRecordWriter(
                    output_filename) as tfrecord_writer:
                start_idx = shard_id * num_per_shard
                end_idx = min((shard_id + 1) * num_per_shard, num_images)
                for i in range(start_idx, end_idx):
                    sys.stdout.write('\r>> Converting image %d/%d shard %d' %
                                     (i + 1, len(filenames), shard_id))
                    sys.stdout.flush()
                    # Read the image.
                    image_filename = os.path.join(
                        self.image_folder,
                        filenames[i] + '.' + self.image_format)
                    image_data = tf.gfile.GFile(image_filename, 'rb').read()
                    height, width = image_reader.read_image_dims(image_data)
                    # Read the semantic segmentation annotation.
                    seg_filename = os.path.join(
                        self.semantic_segmentation_folder,
                        filenames[i] + '.' + self.label_format)
                    seg_data = tf.gfile.GFile(seg_filename, 'rb').read()
                    seg_height, seg_width = label_reader.read_image_dims(
                        seg_data)
                    if height != seg_height or width != seg_width:
                        raise RuntimeError(
                            'Shape mismatched between image and label.')
                    # Convert to tf example.
                    example = build_data.image_seg_to_tfexample(
                        image_data, filenames[i], height, width, seg_data)
                    tfrecord_writer.write(example.SerializeToString())
            sys.stdout.write('\n')
            sys.stdout.flush()
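The sharding above splits num_images examples into _NUM_SHARDS contiguous, roughly equal ranges, one TFRecord file per shard. A standalone sketch of just that arithmetic (hypothetical numbers, no TensorFlow required):

import math

# 10 images across 4 shards -> ranges [0, 3), [3, 6), [6, 9), [9, 10).
num_images = 10
num_shards = 4
num_per_shard = int(math.ceil(num_images / num_shards))

for shard_id in range(num_shards):
    start_idx = shard_id * num_per_shard
    end_idx = min((shard_id + 1) * num_per_shard, num_images)
    print('shard %d -> images [%d, %d)' % (shard_id, start_idx, end_idx))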
Example #2
def _convert_dataset(tfrec_name, dataset_dir, dataset_label_dir):
    img_names = tf.gfile.Glob(os.path.join(dataset_dir,
                                           f'*.{PARAM.image_ext}'))
    random.shuffle(img_names)
    seg_names = []

    for f in img_names:
        basename = os.path.basename(f).split('.')[0]
        seg = os.path.join(dataset_label_dir, f'{basename}.{PARAM.label_ext}')
        seg_names.append(seg)

    num_images = len(img_names)
    num_per_shard = int(math.ceil(num_images / float(PARAM.num_shards)))

    image_reader = build_data.ImageReader(
        'png' if PARAM.image_ext == 'png' else 'jpeg',
        channels=PARAM.image_nchannels)
    label_reader = build_data.ImageReader(
        'png' if PARAM.label_ext == 'png' else 'jpeg',
        channels=PARAM.label_nchannels)

    for shard_id in range(PARAM.num_shards):
        output_filename = os.path.join(
            PARAM.output_dir, '%s-%05d-of-%05d.tfrecord' %
            (tfrec_name, shard_id, PARAM.num_shards))
        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
            start_idx = shard_id * num_per_shard
            end_idx = min((shard_id + 1) * num_per_shard, num_images)
            for i in range(start_idx, end_idx):
                print(
                    f'\r>> Converting image {i+1}/{num_images} shard {shard_id}',
                    end='', flush=True)
                # Read the image.
                image_filename = img_names[i]
                image_data = tf.gfile.FastGFile(image_filename, 'rb').read()
                height, width = image_reader.read_image_dims(image_data)
                # Read the semantic segmentation annotation.
                seg_filename = seg_names[i]
                seg_data = tf.gfile.FastGFile(seg_filename, 'rb').read()
                seg_height, seg_width = label_reader.read_image_dims(seg_data)
                if height != seg_height or width != seg_width:
                    raise RuntimeError(
                        'Shape mismatched between image and label.')
                # Convert to tf example.
                example = build_data.image_seg_to_tfexample(
                    image_data, img_names[i], height, width, seg_data)
                tfrecord_writer.write(example.SerializeToString())
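Both examples use build_data.ImageReader only to recover the decoded height and width so the image and its label can be checked for a matching shape. A rough, dependency-light equivalent of that check is sketched below; it uses PIL instead of the TensorFlow-based reader from build_data, and the file names are hypothetical.

import io
from PIL import Image

def read_image_dims(image_data):
    """Returns (height, width) of an encoded JPEG/PNG byte string."""
    with Image.open(io.BytesIO(image_data)) as img:
        width, height = img.size  # PIL reports (width, height)
    return height, width

with open('example_image.jpg', 'rb') as f:   # hypothetical files
    image_data = f.read()
with open('example_label.png', 'rb') as f:
    seg_data = f.read()

if read_image_dims(image_data) != read_image_dims(seg_data):
    raise RuntimeError('Shape mismatched between image and label.')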
Example #3
def create_deeplab_tfrecords(input_folder, tfrecord_file):
    """Creates a tfrecord file for a given folder

    Parameters:
        input_folder: str, path to samples for a given dataset
        tfrecord_file: str, path to tfrecord that will be created

    Flags:
        See docstring for more information
        color_input: whether to use gray or color images
        multi_class: binary or multi-class segmentation
        location_gradients: location information as extra channels
    """
    label_paths = helper_scripts.get_files_from_folder(input_folder, '.json')
    shuffle(label_paths)
    print('{} label files in {}'.format(len(label_paths), input_folder))

    loc_grad_x = list(
        map(lambda z: z / constants.IMAGE_WIDTH * 255,
            range(constants.IMAGE_WIDTH)))
    loc_grad_y = list(
        map(lambda z: z / constants.IMAGE_HEIGHT * 255,
            range(constants.IMAGE_HEIGHT)))
    loc_grad_x = numpy.asarray([loc_grad_x] * constants.IMAGE_HEIGHT)
    loc_grad_y = numpy.asarray([loc_grad_y] *
                               constants.IMAGE_WIDTH).transpose()
    loc_grad_x = numpy.round(loc_grad_x).astype(numpy.uint8)
    loc_grad_y = numpy.round(loc_grad_y).astype(numpy.uint8)

    os.makedirs(os.path.dirname(tfrecord_file), exist_ok=True)
    with tf.python_io.TFRecordWriter(tfrecord_file) as writer:
        for label_path in tqdm.tqdm(label_paths,
                                    total=len(label_paths),
                                    desc='Creating ' + tfrecord_file):

            image_name = os.path.basename(label_path).replace('.json', '')
            if FLAGS.color_input:
                image_data = label_file_scripts.read_image(label_path,
                                                           image_type='color')
            else:
                image_data = label_file_scripts.read_image(label_path,
                                                           image_type='gray')
                if FLAGS.location_gradients:
                    image_data = numpy.stack(
                        [image_data, loc_grad_x, loc_grad_y], -1)
            image_data = cv2.imencode('.png', image_data)[1].tobytes()

            if FLAGS.multi_class:
                segmentation_label = segmentation_labels.create_multi_class_segmentation_label(
                    label_path)
                segmentation = numpy.zeros(segmentation_label.shape[0:2],
                                           numpy.uint8)
                for class_index in range(1, 5):
                    segmentation[segmentation_label[:, :, class_index] >
                                 0] = class_index
            else:
                segmentation = visualize_labels.create_segmentation_image(
                    label_path, image='blank')
                segmentation = cv2.cvtColor(segmentation, cv2.COLOR_BGR2GRAY)
                segmentation = segmentation > 0
                segmentation = segmentation.astype(numpy.uint8)

            segmentation = cv2.imencode('.png', segmentation)[1].tobytes()

            example = build_data.image_seg_to_tfexample(
                image_data, image_name, constants.IMAGE_HEIGHT,
                constants.IMAGE_WIDTH, segmentation)

            writer.write(example.SerializeToString())
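Records produced by any of these converters can be read back with tf.data. A minimal sketch, assuming the feature keys written by build_data.image_seg_to_tfexample (image/encoded, image/segmentation/class/encoded, image/height, image/width; verify against build_data.py) and PNG-encoded data as in Example #3; the file path is hypothetical.

import tensorflow as tf

# Assumed feature spec; keys should match what image_seg_to_tfexample writes.
features = {
    'image/encoded': tf.io.FixedLenFeature((), tf.string),
    'image/height': tf.io.FixedLenFeature((), tf.int64),
    'image/width': tf.io.FixedLenFeature((), tf.int64),
    'image/segmentation/class/encoded': tf.io.FixedLenFeature((), tf.string),
}

def parse_record(serialized):
    parsed = tf.io.parse_single_example(serialized, features)
    image = tf.image.decode_png(parsed['image/encoded'])
    label = tf.image.decode_png(
        parsed['image/segmentation/class/encoded'], channels=1)
    return image, label

dataset = tf.data.TFRecordDataset('deeplab_train.tfrecord')  # hypothetical path
dataset = dataset.map(parse_record)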