Example #1
    def __init__(
        self,
        tfrecord,  # tfrecords file
        resolution=None,  # Dataset resolution, None = autodetect.
        label_file=None,  # Relative path of the labels file, None = autodetect.
        max_label_size=0,  # 0 = no labels, 'full' = full labels, <int> = N first label components.
        repeat=True,  # Repeat dataset indefinitely?
        shuffle_mb=4096,  # Shuffle data within specified window (megabytes), 0 = disable shuffling.
        prefetch_mb=2048,  # Amount of data to prefetch (megabytes), 0 = disable prefetching.
        buffer_mb=256,  # Read buffer size (megabytes).
        num_threads=2):  # Number of concurrent threads.

        self.tfrecord = tfrecord
        self.resolution = None
        self.res_log2 = None
        self.shape = []  # [channels, height, width]
        self.dtype = 'uint8'
        self.dynamic_range = [0, 255]
        self.label_file = label_file
        self.label_size = None  # Number of label components.
        self.label_dtype = None
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        self._tf_dataset = None
        self._tf_iterator = None
        self._tf_init_op = None
        self._tf_minibatch_np = None
        self._cur_minibatch = -1

        # Check the tfrecords file and inspect its shape.
        assert os.path.isfile(self.tfrecord)

        tfr_file = self.tfrecord
        data_dir = os.path.dirname(tfr_file)

        tfr_opt = tf.python_io.TFRecordOptions(
            tf.python_io.TFRecordCompressionType.NONE)
        for record in tf.python_io.tf_record_iterator(
                tfr_file, tfr_opt):  # only the first record is inspected
            tfr_shape = self.parse_tfrecord_shape(record)  # [c,h,w]
            jpg_data = tfr_shape[0] < 4
            break

        # Autodetect label filename.
        if self.label_file is None:
            # guess = sorted(glob.glob(os.path.join(data_dir, '*.labels')))
            guess = [
                ff for ff in file_list(data_dir, 'labels')
                if basename(ff).split('-')[0] == basename(tfrecord).split('-')[0]
            ]
            if len(guess):
                self.label_file = guess[0]
        elif not os.path.isfile(self.label_file):
            guess = os.path.join(data_dir, self.label_file)
            if os.path.isfile(guess):
                self.label_file = guess

        # Determine shape and resolution
        self.shape = list(tfr_shape)
        max_res = calc_res(tfr_shape[1:])
        self.resolution = resolution if resolution is not None else max_res
        self.res_log2 = int(np.ceil(np.log2(self.resolution)))
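        # Spatial shape at the base (4x4-equivalent) scale; stays non-square for non-square data.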
        self.init_res = [
            int(s * 2**(2 - self.res_log2)) for s in self.shape[1:]
        ]

        # Load labels
        assert max_label_size == 'full' or max_label_size >= 0
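        # Default: a zero-width label array (2**30 rows, 0 columns) so the zip below never runs out of labels when none are loaded.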
        self._np_labels = np.zeros([1 << 30, 0], dtype=np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size:
            self._np_labels = self._np_labels[:, :max_label_size]
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # Build TF expressions.
        with tf.name_scope('Dataset'), tf.device('/cpu:0'):
            self._tf_minibatch_in = tf.placeholder(tf.int64,
                                                   name='minibatch_in',
                                                   shape=[])
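            # Keep labels in a TF variable filled from the numpy array at init time, rather than embedding a potentially huge constant in the graph.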
            self._tf_labels_var = tflib.create_var_with_large_initial_value(
                self._np_labels, name='labels_var')
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(
                self._tf_labels_var)

            dset = tf.data.TFRecordDataset(tfr_file,
                                           compression_type='',
                                           buffer_size=buffer_mb << 20)
            if jpg_data:
                dset = dset.map(self.parse_tfrecord_tf_jpg,
                                num_parallel_calls=num_threads)
            else:
                dset = dset.map(self.parse_tfrecord_tf,
                                num_parallel_calls=num_threads)
            dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))

            bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize
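            # Convert the MB budgets below into element counts for the shuffle and prefetch buffers.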
            if shuffle_mb > 0:
                dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1)
            if repeat:
                dset = dset.repeat()
            if prefetch_mb > 0:
                dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1)
            dset = dset.batch(self._tf_minibatch_in)
            self._tf_dataset = dset

            # self._tf_iterator = tf.data.Iterator.from_structure(tf.data.get_output_types(self._tf_dataset), tf.data.get_output_shapes(self._tf_dataset),)
            self._tf_iterator = tf.data.Iterator.from_structure(
                self._tf_dataset.output_types, self._tf_dataset.output_shapes)
            self._tf_init_op = self._tf_iterator.make_initializer(
                self._tf_dataset)
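
Usage sketch for Example #1 (TF 1.x). This is only a hedged illustration: the class name TFRecordDataset and the tfrecords path are assumptions (the snippet above shows only __init__), and tflib.init_tf() is assumed to set up the default session that create_var_with_large_initial_value() relies on. Only attributes actually built in __init__ are referenced.

import tensorflow as tf
import dnnlib.tflib as tflib

tflib.init_tf()
# Class name and path are hypothetical; shuffling disabled to keep the example light.
dataset = TFRecordDataset('datasets/example-256.tfrecords', shuffle_mb=0)

images_t, labels_t = dataset._tf_iterator.get_next()
sess = tf.get_default_session()
# Initialize the iterator with a minibatch size of 8, then pull one batch.
sess.run(dataset._tf_init_op, {dataset._tf_minibatch_in: 8})
images, labels = sess.run([images_t, labels_t])
print(images.shape, images.dtype, labels.shape)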
Example #2
    def resolution(self):
        assert len(self.image_shape) == 3  # CHW
        # !!! custom init res
        max_res = calc_res(self.image_shape[1:])
        return max_res
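
Both examples rely on an external calc_res helper that is not included in the snippets. The stand-in below is only an assumption about its behaviour (reporting the larger spatial side as the nominal resolution), which is consistent with how res_log2 and init_res are derived from it in Example #1; the real helper may differ.

def calc_res(hw):
    # Hypothetical stand-in, not the original helper: take [height, width]
    # and return the larger side as the dataset's nominal resolution.
    # Example #1 then rounds it up to a power of two via np.ceil(np.log2(...)).
    h, w = int(hw[0]), int(hw[1])
    return max(h, w)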