Exemple #1
0
 def configure(self, minibatch_size, lod=0):
     """Store the minibatch size and the floored level of detail in the TF variables.

     `lod` may be fractional; it is rounded down to an integer first. Both
     values are range-checked with assertions before being written through
     `tfutil.set_vars`.
     """
     lod_index = int(np.floor(lod))
     assert minibatch_size >= 1
     assert 0 <= lod_index <= self.resolution_log2
     updates = {}
     updates[self._tf_minibatch_var] = minibatch_size
     updates[self._tf_lod_var] = lod_index
     tfutil.set_vars(updates)
    def __init__(self,
        tfrecord_dir,               # Directory containing a collection of tfrecords files.
        resolution      = None,     # Dataset resolution, None = autodetect.
        label_file      = None,     # Relative path of the labels file, None = autodetect.
        max_label_size  = 0,        # 0 = no labels, 'full' = full labels, <int> = N first label components.
        repeat          = True,     # Repeat dataset indefinitely.
        shuffle_mb      = 4096,     # Shuffle data within specified window (megabytes), 0 = disable shuffling.
        prefetch_mb     = 2048,     # Amount of data to prefetch (megabytes), 0 = disable prefetching.
        buffer_mb       = 256,      # Read buffer size (megabytes).
        num_threads     = 2):       # Number of concurrent threads.
        """Index a directory of tfrecords files and build one TF dataset per level of detail.

        The first record of every `*.tfrecords` file is parsed to learn the
        file's image shape. The labels file is then resolved and loaded, and a
        dataset pipeline (parse, zip with labels, optional shuffle / repeat /
        prefetch, batch) is created for every file whose lod is non-negative.
        A single iterator plus one initializer op per lod is stored on `self`.
        """
        # Public metadata.
        self.tfrecord_dir = tfrecord_dir
        self.resolution = None
        self.resolution_log2 = None
        self.shape = []  # [channel, height, width]
        self.dtype = 'uint8'
        self.dynamic_range = [0, 255]
        self.label_file = label_file
        self.label_size = None  # [component]
        self.label_dtype = None

        # Internal state, filled in below.
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        self._tf_datasets = dict()
        self._tf_iterator = None
        self._tf_init_ops = dict()
        self._tf_minibatch_np = None
        self._cur_minibatch = -1
        self._cur_lod = -1

        # Collect the tfrecords files and read the shape of the first record in each.
        assert os.path.isdir(self.tfrecord_dir)
        tfr_pattern = os.path.join(self.tfrecord_dir, '*.tfrecords')
        tfr_files = sorted(glob.glob(tfr_pattern))
        assert len(tfr_files) >= 1
        tfr_shapes = []
        for tfr_file in tfr_files:
            tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
            for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt):
                first_shape = parse_tfrecord_np(record).shape
                tfr_shapes.append(first_shape)
                break  # Only the first record is inspected.

        # Resolve the labels file: autodetect, or retry relative to the dataset directory.
        if self.label_file is None:
            label_candidates = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.labels')))
            if len(label_candidates):
                self.label_file = label_candidates[0]
        elif not os.path.isfile(self.label_file):
            relative_guess = os.path.join(self.tfrecord_dir, self.label_file)
            if os.path.isfile(relative_guess):
                self.label_file = relative_guess

        # Derive the dataset shape, resolution and the lod of every file.
        max_shape = max(tfr_shapes, key=lambda shape: np.prod(shape))
        if resolution is not None:
            self.resolution = resolution
        else:
            self.resolution = max_shape[1]
        self.resolution_log2 = int(np.log2(self.resolution))
        self.shape = [max_shape[0], self.resolution, self.resolution]
        tfr_lods = []
        for shape in tfr_shapes:
            tfr_lods.append(self.resolution_log2 - int(np.log2(shape[1])))
        assert all(shape[0] == max_shape[0] for shape in tfr_shapes)
        assert all(shape[1] == shape[2] for shape in tfr_shapes)
        assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods))
        assert all(lod in tfr_lods for lod in range(self.resolution_log2 - 1))

        # Load the labels, optionally truncated to the first N components.
        assert max_label_size == 'full' or max_label_size >= 0
        self._np_labels = np.zeros([1<<20, 0], dtype=np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size:
            self._np_labels = self._np_labels[:, :max_label_size]
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # Build the TF expressions.
        with tf.name_scope('Dataset'), tf.device('/cpu:0'):
            self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[])
            labels_zeros = tf.zeros(self._np_labels.shape, self._np_labels.dtype)
            self._tf_labels_var = tf.Variable(labels_zeros, name='labels_var')
            tfutil.set_vars({self._tf_labels_var: self._np_labels})
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var)

            for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods):
                if tfr_lod < 0:
                    continue  # File is larger than the requested resolution.
                images = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb<<20)
                images = images.map(parse_tfrecord_tf, num_parallel_calls=num_threads)
                pipeline = tf.data.Dataset.zip((images, self._tf_labels_dataset))
                # Megabyte budgets are converted to item counts (rounded up).
                bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize
                if shuffle_mb > 0:
                    shuffle_items = ((shuffle_mb << 20) - 1) // bytes_per_item + 1
                    pipeline = pipeline.shuffle(shuffle_items)
                if repeat:
                    pipeline = pipeline.repeat()
                if prefetch_mb > 0:
                    prefetch_items = ((prefetch_mb << 20) - 1) // bytes_per_item + 1
                    pipeline = pipeline.prefetch(prefetch_items)
                pipeline = pipeline.batch(self._tf_minibatch_in)
                self._tf_datasets[tfr_lod] = pipeline

            # One iterator shared by all lods; the lod-0 dataset supplies its structure.
            reference = self._tf_datasets[0]
            self._tf_iterator = tf.data.Iterator.from_structure(reference.output_types, reference.output_shapes)
            init_ops = dict()
            for lod, dset in self._tf_datasets.items():
                init_ops[lod] = self._tf_iterator.make_initializer(dset)
            self._tf_init_ops = init_ops
Exemple #3
0
    def __init__(
        self,
        tfrecord_dir,  # Directory that is searched for whole-slide image (WSI) files.
        wsi_ext='.svs',  # extension of WSIs to look for in the dataset folder
        resolution=512,  # Dataset patch size (resolution)
        downsample=4,  # starting downsample of patches from WSI
        label_file=None,  # Path of the labels file given to np.load; None = no labels file (no autodetection here).
        max_label_size=0,  # 0 = no labels, 'full' = full labels, <int> = N first label components.
        repeat=True,  # Repeat dataset indefinitely.
        shuffle_mb=4096,  # Accepted but not used by this constructor (fixed shuffle windows of 100 are used).
        prefetch_mb=2048,  # Accepted but not used by this constructor (a fixed prefetch of 2 is used).
        buffer_mb=256,  # Accepted but not used by this constructor.
        num_threads=30):  # Number of parallel calls used when mapping the patch sampler.
        """Build one TF dataset per level of detail that samples patches from WSI files.

        The directory is globbed for files ending in `wsi_ext`, and
        `save_wsi_thumbnail_mask` is called on each of them. For every lod from
        `resolution_log2 - 2` down to 0, a dataset is created that maps
        `get_random_wsi_batch` over the shuffled file paths through
        `tf.py_function`, zips the result with the labels dataset, then shuffles,
        optionally repeats, batches and prefetches. A single iterator and one
        initializer op per lod are stored on `self`.
        """
        self.tfrecord_dir = tfrecord_dir
        self.wsi_ext = wsi_ext
        self.resolution = resolution
        self.downsample = downsample
        self.resolution_log2 = None
        self.shape = []  # [channel, height, width]
        self.dtype = 'uint8'
        self.dynamic_range = [0, 255]
        self.label_file = label_file
        self.label_size = None  # [component]
        self.label_dtype = None
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        self._tf_datasets = dict()
        self._tf_iterator = None
        self._tf_init_ops = dict()
        self._tf_minibatch_np = None
        self._cur_minibatch = -1
        self._cur_lod = -1

        # List the WSI files found in the dataset directory.
        assert os.path.isdir(self.tfrecord_dir)
        wsi_files = sorted(
            glob.glob(
                os.path.join(self.tfrecord_dir, '*{}'.format(self.wsi_ext))))
        assert len(wsi_files) >= 1
        print('Found {} WSI files'.format(len(wsi_files)))

        print('Creating WSI Tissue Masks....')
        for file_path in wsi_files:
            save_wsi_thumbnail_mask(file_path)

        # Determine shape and resolution. The channel count is fixed at 3.
        self.resolution_log2 = int(np.log2(self.resolution))
        self.shape = [3, self.resolution, self.resolution]
        # Lods run from resolution_log2 - 2 down to 0, so they are never negative.
        tfr_lods = [
            self.resolution_log2 - i
            for i in range(2, self.resolution_log2 + 1)
        ]

        # Load labels.
        assert max_label_size == 'full' or max_label_size >= 0
        self._np_labels = np.zeros([1 << 20, 0], dtype=np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != 'full' and self._np_labels.shape[
                1] > max_label_size:
            self._np_labels = self._np_labels[:, :max_label_size]
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # Build TF expressions.
        print('Setting Up Dataset...')
        with tf.name_scope('Dataset'), tf.device('/cpu:0'):
            self._tf_minibatch_in = tf.placeholder(tf.int64,
                                                   name='minibatch_in',
                                                   shape=[])

            tf_labels_init = tf.zeros(self._np_labels.shape,
                                      self._np_labels.dtype)
            self._tf_labels_var = tf.Variable(tf_labels_init,
                                              name='labels_var')
            tfutil.set_vars({self._tf_labels_var: self._np_labels})
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(
                self._tf_labels_var)

            for tfr_lod in tfr_lods:
                # Each lod halves the patch size and doubles the downsample factor.
                output_size = self.resolution // (2**tfr_lod)
                this_downsample = (2**tfr_lod) * self.downsample

                print('lod: {} | patch_size: {} | downsample: {}'.format(
                    tfr_lod, output_size, this_downsample))

                path_ds = tf.data.Dataset.from_tensor_slices(wsi_files)
                path_ds = path_ds.shuffle(100)
                dset = path_ds.map(
                    lambda filename: tf.py_function(get_random_wsi_batch, [
                        filename, output_size, False, this_downsample
                    ], tf.float32),
                    num_parallel_calls=num_threads)

                dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))
                dset = dset.shuffle(100)
                if repeat:
                    dset = dset.repeat()
                dset = dset.batch(self._tf_minibatch_in)
                dset = dset.prefetch(buffer_size=2)
                self._tf_datasets[tfr_lod] = dset
            # One iterator shared by all lods; the lod-0 dataset supplies its structure.
            self._tf_iterator = tf.data.Iterator.from_structure(
                self._tf_datasets[0].output_types,
                self._tf_datasets[0].output_shapes)
            self._tf_init_ops = {
                lod: self._tf_iterator.make_initializer(dset)
                for lod, dset in self._tf_datasets.items()
            }
    def __init__(
        self,
        tfrecord_dir,  # Directory containing a collection of tfrecords files.
        resolution=None,  # Dataset resolution, None = autodetect.
        label_file=None,  # Relative path of the labels file, None = autodetect.
        labeltypes=None,  # Label columns to keep; None or empty = no labels. Can include: 0 for 'channelorientation', 1 for 'mudproportion', 2 for 'channelwidth', 3 for 'channelsinuosity'
        repeat=True,  # Repeat dataset indefinitely.
        shuffle_mb=4096,  # Shuffle data within specified window (megabytes), 0 = disable shuffling.
        prefetch_mb=2048,  # Amount of data to prefetch (megabytes), 0 = disable prefetching.
        buffer_mb=256,  # Read buffer size (megabytes).
        num_threads=2,  # Number of concurrent threads.
        well_enlarge=False):  # If enlarged well points are outputted
        """Build the real-image datasets plus a paired prob-image / well-facies dataset.

        The sorted `*.tfrecords` files in `tfrecord_dir` are split by position:
        all but the last two are real-image files (one per level of detail),
        the second to last is the prob-image file and the last is the
        well-facies file. Real-image files each get a pipeline zipped with the
        labels and share one iterator with one initializer op per lod. The
        prob-image and well-facies files are zipped together into a separate
        dataset with its own iterator and initializer op.

        `labeltypes` selects columns of the labels array; `None` is treated as
        an empty selection.
        """
        self.tfrecord_dir = tfrecord_dir
        self.resolution = None
        self.resolution_log2 = None
        self.shape = []  # [channel, height, width]
        self.dtype = 'uint8'
        self.dynamic_range = [0, 255]
        self.label_file = label_file
        self.label_size = 0  # [component]
        self.label_dtype = None
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        self._tf_probimages_dataset = None
        self._tf_wellfacies_dataset = None
        self._tf_datasets = dict()
        self._tf_iterator = None
        self._tf_init_ops = dict()
        self._tf_minibatch_np = None
        self._cur_minibatch = -1
        self._cur_lod = -1
        self.well_enlarge = well_enlarge

        # The same uncompressed-record options are used for every file inspected below.
        tfr_opt = tf.python_io.TFRecordOptions(
            tf.python_io.TFRecordCompressionType.NONE)

        # List real-image tfrecords files and inspect their shapes.
        # The sorted file list ends with one prob-image file and one
        # well-facies file, so [:-2] keeps only the real-image files.
        assert os.path.isdir(self.tfrecord_dir)
        tfr_files = sorted(
            glob.glob(os.path.join(self.tfrecord_dir, '*.tfrecords')))
        tfr_realimage_files = tfr_files[:-2]
        assert len(tfr_realimage_files) >= 1
        tfr_realimage_shapes = []
        for tfr_realimage_file in tfr_realimage_files:
            for record in tf.python_io.tf_record_iterator(
                    tfr_realimage_file, tfr_opt):
                tfr_realimage_shapes.append(parse_tfrecord_np(record).shape)
                break  # Only the first record is inspected.

        # Inspect the shape of the prob-image file (second to last in sorted order).
        # NOTE(review): the shape is read but not used further in this constructor.
        tfr_probimage_file = tfr_files[-2]
        for record in tf.python_io.tf_record_iterator(tfr_probimage_file,
                                                      tfr_opt):
            tfr_probimage_shape = parse_tfrecord_np_float16(record).shape
            break

        # Inspect the shape of the well-facies file (last in sorted order).
        # Well facies only contain codes 0 (no well facies), 1 (mud well
        # facies), 2 (levee), 3 (channel).
        # NOTE(review): the shape is read but not used further in this constructor.
        tfr_wellfacies_file = tfr_files[-1]
        for record in tf.python_io.tf_record_iterator(tfr_wellfacies_file,
                                                      tfr_opt):
            tfr_wellfacies_shape = parse_tfrecord_np(record).shape
            break

        # Autodetect label filename.
        if self.label_file is None:
            guess = sorted(
                glob.glob(os.path.join(self.tfrecord_dir, '*.labels')))
            if len(guess):
                self.label_file = guess[0]
        elif not os.path.isfile(self.label_file):
            guess = os.path.join(self.tfrecord_dir, self.label_file)
            if os.path.isfile(guess):
                self.label_file = guess

        # Determine shape and resolution from the real-image files.
        max_realimage_shape = max(tfr_realimage_shapes,
                                  key=lambda shape: np.prod(shape))
        self.resolution = (resolution if resolution is not None else
                           max_realimage_shape[1])
        self.resolution_log2 = int(np.log2(self.resolution))
        self.shape = [max_realimage_shape[0], self.resolution, self.resolution]
        tfr_realimage_lods = [
            self.resolution_log2 - int(np.log2(shape[1]))
            for shape in tfr_realimage_shapes
        ]
        assert all(shape[0] == max_realimage_shape[0]
                   for shape in tfr_realimage_shapes)
        assert all(shape[1] == shape[2] for shape in tfr_realimage_shapes)
        assert all(
            shape[1] == self.resolution // (2**lod)
            for shape, lod in zip(tfr_realimage_shapes, tfr_realimage_lods))
        assert all(lod in tfr_realimage_lods
                   for lod in range(self.resolution_log2 - 1))

        # Load labels. The default labeltypes=None means "no label columns";
        # calling len() on it directly would raise TypeError.
        if labeltypes is None:
            labeltypes = []
        self.label_size = len(labeltypes)
        self._np_labels = np.zeros([1 << 17, 0], dtype=np.float32)
        if self.label_size > 0:
            self._np_labels = np.load(self.label_file)[:, labeltypes]
        elif self.label_file is not None:
            # Keep the row count and dtype of the labels file, with zero columns.
            self._np_labels = np.load(self.label_file)[:, :0]
        # Without labels and without a labels file, the empty placeholder above is kept.
        self.label_dtype = self._np_labels.dtype.name

        # Build TF expressions.
        # NOTE(review): sizes below are shifted by 17, i.e. expressed in units
        # of 128 KiB, not the megabytes the parameter comments state - confirm
        # this is intended.
        with tf.name_scope('Dataset'), tf.device('/cpu:0'):
            self._tf_minibatch_in = tf.placeholder(tf.int64,
                                                   name='minibatch_in',
                                                   shape=[])
            tf_labels_init = tf.zeros(self._np_labels.shape,
                                      self._np_labels.dtype)
            self._tf_labels_var = tf.Variable(tf_labels_init,
                                              name='labels_var')
            tfutil.set_vars({self._tf_labels_var: self._np_labels})
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(
                self._tf_labels_var)

            for tfr_realimage_file, tfr_realimage_shape, tfr_realimage_lod in zip(
                    tfr_realimage_files, tfr_realimage_shapes,
                    tfr_realimage_lods):
                if tfr_realimage_lod < 0:
                    continue
                dset = tf.data.TFRecordDataset(tfr_realimage_file,
                                               compression_type='',
                                               buffer_size=buffer_mb << 17)
                dset = dset.map(parse_tfrecord_tf,
                                num_parallel_calls=num_threads)
                dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))
                bytes_per_item = np.prod(tfr_realimage_shape) * np.dtype(
                    self.dtype).itemsize
                if shuffle_mb > 0:
                    dset = dset.shuffle((
                        (shuffle_mb << 17) - 1) // bytes_per_item + 1)
                if repeat:
                    dset = dset.repeat()
                if prefetch_mb > 0:
                    dset = dset.prefetch((
                        (prefetch_mb << 17) - 1) // bytes_per_item + 1)
                dset = dset.batch(self._tf_minibatch_in)
                self._tf_datasets[tfr_realimage_lod] = dset
            # One iterator shared by all lods; the lod-0 dataset supplies its structure.
            self._tf_iterator = tf.data.Iterator.from_structure(
                self._tf_datasets[0].output_types,
                self._tf_datasets[0].output_shapes)
            self._tf_init_ops = {
                lod: self._tf_iterator.make_initializer(dset)
                for lod, dset in self._tf_datasets.items()
            }

            # Paired prob-image / well-facies dataset.
            tf_probimages_dset = tf.data.TFRecordDataset(
                tfr_probimage_file,
                compression_type='',
                buffer_size=buffer_mb << 17)
            tf_probimages_dset = tf_probimages_dset.map(
                parse_tfrecord_tf_float16, num_parallel_calls=num_threads)
            self._tf_probimages_dataset = tf_probimages_dset
            tf_wellfacies_dset = tf.data.TFRecordDataset(
                tfr_wellfacies_file,
                compression_type='',
                buffer_size=buffer_mb << 17)
            tf_wellfacies_dset = tf_wellfacies_dset.map(
                parse_tfrecord_tf, num_parallel_calls=num_threads)
            self._tf_wellfacies_dataset = tf_wellfacies_dset
            paired_dset = tf.data.Dataset.zip(
                (self._tf_probimages_dataset, self._tf_wellfacies_dataset))
            # bytes_per_item still holds the value from the last real-image
            # file processed in the loop above; the buffer sizes here are
            # derived from it rather than from the prob-image or well-facies
            # shapes.
            if shuffle_mb > 0:
                paired_dset = paired_dset.shuffle(
                    ((shuffle_mb << 17) - 1) // bytes_per_item + 1)
            if repeat:
                paired_dset = paired_dset.repeat()
            if prefetch_mb > 0:
                paired_dset = paired_dset.prefetch(
                    ((prefetch_mb << 17) - 1) // bytes_per_item + 1)
            paired_dset = paired_dset.batch(self._tf_minibatch_in)
            self._tf_probimages_wellfacies_dset = paired_dset
            self._tf_probimages_wellfacies_iterator = tf.data.Iterator.from_structure(
                self._tf_probimages_wellfacies_dset.output_types,
                self._tf_probimages_wellfacies_dset.output_shapes)
            self._tf_probimages_wellfacies_init_ops = self._tf_probimages_wellfacies_iterator.make_initializer(
                self._tf_probimages_wellfacies_dset)
Exemple #5
0
    def __init__(
        self,
        tfrecord_file,  # Path of a single tfrecords file.
        resolution=None,  # Dataset resolution, None = autodetect.
        label_file=None,  # Path of the labels file, None = no labels file (no autodetection here).
        max_label_size=0,  # 0 = no labels, 'full' = full labels, <int> = N first label components.
        buffer_mb=256,  # Read buffer size (megabytes).
        num_threads=2):  # Number of concurrent threads.
        """Build a batched, non-shuffled, non-repeating TF dataset from one tfrecords file.

        The first record of the file is parsed to learn the image shape, the
        labels are loaded when a labels file is given and `max_label_size` is
        non-zero, and a pipeline (parse, zip with labels, batch) is created for
        the file. The iterator structure is taken from the lod-0 dataset, so
        `resolution` must be left at `None` or match the file's resolution.
        """
        self.tfrecord_file = tfrecord_file
        self.resolution = None
        self.resolution_log2 = None
        self.shape = []  # [channel, height, width]
        self.dtype = 'uint8'
        self.dynamic_range = [0, 255]
        self.label_file = label_file
        self.label_size = None  # [component]
        self.label_dtype = None
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        self._tf_datasets = dict()
        self._tf_iterator = None
        self._tf_init_ops = dict()
        self._tf_minibatch_np = None
        self._cur_minibatch = -1
        self._cur_lod = -1

        assert os.path.isfile(self.tfrecord_file)
        # label_file defaults to None, which os.path.isfile() rejects with a
        # TypeError; only check the path when one was actually given.
        assert self.label_file is None or os.path.isfile(self.label_file)

        # Read the shape of the first record only.
        tfr_shapes = []
        tfr_opt = tf.python_io.TFRecordOptions(
            tf.python_io.TFRecordCompressionType.NONE)
        for record in tf.python_io.tf_record_iterator(self.tfrecord_file,
                                                      tfr_opt):
            tfr_shapes.append(parse_tfrecord_np(record).shape)
            break

        tfr_files = [self.tfrecord_file]

        # Determine shape and resolution.
        max_shape = max(tfr_shapes, key=lambda shape: np.prod(shape))
        self.resolution = resolution if resolution is not None else max_shape[1]
        print("----------------------------", flush=True)
        print("Resolution used:", self.resolution, flush=True)
        print("----------------------------", flush=True)
        self.resolution_log2 = int(np.log2(self.resolution))
        self.shape = [max_shape[0], self.resolution, self.resolution]
        tfr_lods = [
            self.resolution_log2 - int(np.log2(shape[1]))
            for shape in tfr_shapes
        ]
        assert all(shape[0] == max_shape[0] for shape in tfr_shapes)
        assert all(shape[1] == shape[2] for shape in tfr_shapes)
        assert all(shape[1] == self.resolution // (2**lod)
                   for shape, lod in zip(tfr_shapes, tfr_lods))

        print("tfr_lods:", tfr_lods, flush=True)
        print("tfr_shapes:", tfr_shapes, flush=True)
        print("tfr_files:", tfr_files, flush=True)

        # No check that every lod is present: only a single file (one lod) is
        # read by this constructor.

        # Load labels.
        assert max_label_size == 'full' or max_label_size >= 0
        self._np_labels = np.zeros([1 << 20, 0], dtype=np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != 'full' and self._np_labels.shape[
                1] > max_label_size:
            self._np_labels = self._np_labels[:, :max_label_size]
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # Build TF expressions.
        with tf.name_scope('Dataset'), tf.device('/cpu:0'):
            self._tf_minibatch_in = tf.placeholder(tf.int64,
                                                   name='minibatch_in',
                                                   shape=[])
            tf_labels_init = tf.zeros(self._np_labels.shape,
                                      self._np_labels.dtype)
            self._tf_labels_var = tf.Variable(tf_labels_init,
                                              name='labels_var')
            tfutil.set_vars({self._tf_labels_var: self._np_labels})
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(
                self._tf_labels_var)
            for tfr_file, tfr_lod in zip(tfr_files, tfr_lods):
                if tfr_lod < 0:
                    continue  # File is larger than the requested resolution.
                dset = tf.data.TFRecordDataset(tfr_file,
                                               compression_type='',
                                               buffer_size=buffer_mb << 20)
                dset = dset.map(parse_tfrecord_tf,
                                num_parallel_calls=num_threads)
                dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))
                dset = dset.batch(self._tf_minibatch_in)
                self._tf_datasets[tfr_lod] = dset
            # The lod-0 dataset supplies the iterator structure.
            self._tf_iterator = tf.data.Iterator.from_structure(
                self._tf_datasets[0].output_types,
                self._tf_datasets[0].output_shapes)
            self._tf_init_ops = {
                lod: self._tf_iterator.make_initializer(dset)
                for lod, dset in self._tf_datasets.items()
            }
    def __init__(self,
        tfrecord_dir,               # Directory containing a collection of tfrecords files.
        resolution      = None,     # Dataset resolution, None = autodetect.
        label_file      = None,     # Relative path of the labels file, None = autodetect.
        max_label_size  = 0,        # 0 = no labels, 'full' = full labels, <int> = N first label components.
        repeat          = True,     # Repeat dataset indefinitely.
        shuffle_mb      = 4096,     # Shuffle data within specified window (megabytes), 0 = disable shuffling.
        prefetch_mb     = 2048,     # Amount of data to prefetch (megabytes), 0 = disable prefetching.
        buffer_mb       = 256,      # Read buffer size (megabytes).
        num_threads     = 2):       # Number of concurrent threads.
        """Set up per-lod TF datasets for the tfrecords files found in a directory.

        Each `*.tfrecords` file contributes the shape of its first record. From
        those shapes the dataset resolution and every file's level of detail
        are derived and validated with assertions. Labels are loaded from the
        resolved labels file, and each file with a non-negative lod gets a
        pipeline of parse, zip with labels, optional shuffle / repeat /
        prefetch, and batch. One shared iterator and one initializer op per
        lod are kept on `self`.
        """
        # Public metadata.
        self.tfrecord_dir = tfrecord_dir
        self.resolution = None
        self.resolution_log2 = None
        self.shape = []  # [channel, height, width]
        self.dtype = 'uint8'
        self.dynamic_range = [0, 255]
        self.label_file = label_file
        self.label_size = None  # [component]
        self.label_dtype = None

        # Internal state, filled in below.
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        self._tf_datasets = dict()
        self._tf_iterator = None
        self._tf_init_ops = dict()
        self._tf_minibatch_np = None
        self._cur_minibatch = -1
        self._cur_lod = -1

        # Find the tfrecords files and record the shape of each file's first record.
        assert os.path.isdir(self.tfrecord_dir)
        records_glob = os.path.join(self.tfrecord_dir, '*.tfrecords')
        tfr_files = sorted(glob.glob(records_glob))
        assert len(tfr_files) >= 1
        tfr_shapes = []
        for path in tfr_files:
            options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
            for record in tf.python_io.tf_record_iterator(path, options):
                parsed = parse_tfrecord_np(record)
                tfr_shapes.append(parsed.shape)
                break  # Only the first record is inspected.

        # Labels file: autodetect when unset, otherwise retry relative to the directory.
        if self.label_file is None:
            found = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.labels')))
            if len(found):
                self.label_file = found[0]
        elif not os.path.isfile(self.label_file):
            candidate = os.path.join(self.tfrecord_dir, self.label_file)
            if os.path.isfile(candidate):
                self.label_file = candidate

        # Shape, resolution and per-file lods.
        max_shape = max(tfr_shapes, key=lambda shape: np.prod(shape))
        if resolution is None:
            self.resolution = max_shape[1]
        else:
            self.resolution = resolution
        self.resolution_log2 = int(np.log2(self.resolution))
        self.shape = [max_shape[0], self.resolution, self.resolution]
        tfr_lods = [self.resolution_log2 - int(np.log2(shape[1])) for shape in tfr_shapes]
        channels = max_shape[0]
        assert all(shape[0] == channels for shape in tfr_shapes)
        assert all(shape[1] == shape[2] for shape in tfr_shapes)
        assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods))
        assert all(lod in tfr_lods for lod in range(self.resolution_log2 - 1))

        # Labels, optionally truncated to the first N components.
        assert max_label_size == 'full' or max_label_size >= 0
        self._np_labels = np.zeros([1<<20, 0], dtype=np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size:
            self._np_labels = self._np_labels[:, :max_label_size]
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # TF expressions.
        with tf.name_scope('Dataset'), tf.device('/cpu:0'):
            self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[])
            initial_labels = tf.zeros(self._np_labels.shape, self._np_labels.dtype)
            self._tf_labels_var = tf.Variable(initial_labels, name='labels_var')
            tfutil.set_vars({self._tf_labels_var: self._np_labels})
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var)

            for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods):
                if tfr_lod < 0:
                    continue  # File is larger than the requested resolution.
                stream = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb<<20)
                stream = stream.map(parse_tfrecord_tf, num_parallel_calls=num_threads)
                stream = tf.data.Dataset.zip((stream, self._tf_labels_dataset))
                # Megabyte budgets are converted to item counts (rounded up).
                bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize
                if shuffle_mb > 0:
                    window = ((shuffle_mb << 20) - 1) // bytes_per_item + 1
                    stream = stream.shuffle(window)
                if repeat:
                    stream = stream.repeat()
                if prefetch_mb > 0:
                    lookahead = ((prefetch_mb << 20) - 1) // bytes_per_item + 1
                    stream = stream.prefetch(lookahead)
                self._tf_datasets[tfr_lod] = stream.batch(self._tf_minibatch_in)

            # One iterator shared by all lods; the lod-0 dataset supplies its structure.
            base = self._tf_datasets[0]
            self._tf_iterator = tf.data.Iterator.from_structure(base.output_types, base.output_shapes)
            initializers = dict()
            for lod, dset in self._tf_datasets.items():
                initializers[lod] = self._tf_iterator.make_initializer(dset)
            self._tf_init_ops = initializers
 def configure(self, minibatch_size, lod=0):
     """Write the minibatch size and the floored lod into their TF variables.

     A fractional `lod` is rounded down to an integer. An AssertionError is
     raised if `minibatch_size` is below 1 or the floored lod falls outside
     `[0, self.resolution_log2]`.
     """
     floored_lod = int(np.floor(lod))
     assert minibatch_size >= 1 and 0 <= floored_lod <= self.resolution_log2
     new_values = {
         self._tf_minibatch_var: minibatch_size,
         self._tf_lod_var: floored_lod,
     }
     tfutil.set_vars(new_values)