示例#1
0
    def __init__(
        self,
        tfrecord_dir,  # Directory containing a collection of tfrecords files.
        res_log2=7,  # NOTE(review): unused in this constructor (the related assignment below is commented out) — confirm whether callers rely on it.
        min_h=4,  # Stored on the instance; presumably the minimum feature-map height — confirm against the consuming network code.
        min_w=4,  # Stored on the instance; presumably the minimum feature-map width — confirm against the consuming network code.
        resolution=None,  # Dataset resolution, None = autodetect.
        label_file=None,  # Relative path of the labels file, None = autodetect.
        max_label_size=0,  # 0 = no labels, 'full' = full labels, <int> = N first label components.
        repeat=True,  # Repeat dataset indefinitely.
        shuffle_mb=4096,  # Shuffle data within specified window (megabytes), 0 = disable shuffling.
        prefetch_mb=2048,  # Amount of data to prefetch (megabytes), 0 = disable prefetching.
        buffer_mb=256,  # Read buffer size (megabytes).
        num_threads=2,
    ):  # Number of concurrent threads.
        """Build a TF1 input pipeline over the highest-resolution tfrecords file.

        Scans ``tfrecord_dir`` for ``*.tfrecords`` files, reads the first record
        of each to learn its shape, optionally loads a ``.labels`` file, and
        constructs a single tf.data pipeline (parse -> zip with labels ->
        shuffle -> repeat -> prefetch -> batch) over the LAST file in sorted
        order, which is assumed to be the highest-resolution one.
        """

        self.tfrecord_dir = tfrecord_dir
        #self.res_log2 = res_log2
        #self.resolution = None
        #self.resolution_log2 = None
        self.shape = []  # [channel, height, width]
        self.dtype = "uint8"
        self.dynamic_range = [0, 255]
        self.label_file = label_file
        self.label_size = None  # [component]
        self.label_dtype = None
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        # NOTE(review): the plural containers below are initialized, but this
        # constructor only ever assigns the singular `_tf_dataset` /
        # `_tf_init_op` at the end — confirm which form the rest of the class
        # actually reads.
        self._tf_datasets = dict()
        self._tf_iterator = None
        self._tf_init_ops = dict()
        self._tf_minibatch_np = None
        self._cur_minibatch = -1
        self.min_h = min_h
        self.min_w = min_w
        # List tfrecords files and inspect their shapes.
        assert os.path.isdir(self.tfrecord_dir)
        tfr_files = sorted(
            glob.glob(os.path.join(self.tfrecord_dir, "*.tfrecords")))
        assert len(tfr_files) >= 1
        tfr_shapes = []
        for tfr_file in tfr_files:
            tfr_opt = tf.python_io.TFRecordOptions(
                tf.python_io.TFRecordCompressionType.NONE)
            # Only the first record of each file is inspected; `break` exits
            # after one iteration.
            for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt):
                #tfr_shapes.append(parse_tfrecord_np(record).shape)
                # parse_tfrecord_np_aydao appears to return the shape itself
                # (no `.shape` access here, unlike the commented line above) —
                # TODO confirm against its definition.
                tfr_shapes.append(parse_tfrecord_np_aydao(record))
                break

        # Autodetect label filename.
        if self.label_file is None:
            guess = sorted(
                glob.glob(os.path.join(self.tfrecord_dir, "*.labels")))
            if len(guess):
                self.label_file = guess[0]
        elif not os.path.isfile(self.label_file):
            # Resolve a relative labels path against the dataset directory.
            guess = os.path.join(self.tfrecord_dir, self.label_file)
            if os.path.isfile(guess):
                self.label_file = guess

        # Determine shape and resolution.
        max_shape = max(tfr_shapes, key=np.prod)
        #self.resolution = resolution if resolution is not None else max_shape[1]
        #self.resolution_log2 = int(np.log2(self.resolution))
        #self.shape = [max_shape[0], self.resolution, self.resolution]
        # Unlike the commented-out square-image logic above, height and width
        # are kept independent, so non-square datasets are allowed.
        self.shape = [max_shape[0], max_shape[1], max_shape[2]]
        assert all(shape[0] == max_shape[0] for shape in tfr_shapes)
        #assert all(shape[1] == shape[2] for shape in tfr_shapes)

        # Load labels.
        assert max_label_size == "full" or max_label_size >= 0
        # Placeholder label array (1M rows, zero components) used when no
        # labels file is loaded.
        self._np_labels = np.zeros([1 << 20, 0], dtype=np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != "full" and self._np_labels.shape[
                1] > max_label_size:
            # Keep only the first `max_label_size` label components.
            self._np_labels = self._np_labels[:, :max_label_size]
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # Build TF expressions.
        with tf.name_scope("Dataset"), tf.device("/cpu:0"):
            self._tf_minibatch_in = tf.placeholder(tf.int64,
                                                   name="minibatch_in",
                                                   shape=[])
            self._tf_labels_var = tflib.create_var_with_large_initial_value(
                self._np_labels, name="labels_var")
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(
                self._tf_labels_var)

            # only use the full resolution version file tfr_file
            # This portion of the code has been edited to make sure that the Dataset
            # only returns the highest resolution files
            tfr_file = tfr_files[
                -1]  # should be the highest resolution tf_record file
            tfr_shape = tfr_shapes[-1]  # again the highest resolution shape
            dset = tf.data.TFRecordDataset(tfr_file,
                                           compression_type="",
                                           buffer_size=buffer_mb << 20)
            dset = dset.map(parse_tfrecord_tf_aydao,
                            num_parallel_calls=num_threads)
            dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))
            # Convert the megabyte budgets into item counts for the shuffle
            # window and prefetch depth (rounded up to at least 1).
            bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize
            if shuffle_mb > 0:
                dset = dset.shuffle((
                    (shuffle_mb << 20) - 1) // bytes_per_item + 1)
            if repeat:
                dset = dset.repeat()
            if prefetch_mb > 0:
                dset = dset.prefetch((
                    (prefetch_mb << 20) - 1) // bytes_per_item + 1)
            dset = dset.batch(self._tf_minibatch_in)
            self._tf_dataset = dset
            self._tf_iterator = tf.data.Iterator.from_structure(
                self._tf_dataset.output_types, self._tf_dataset.output_shapes)
            self._tf_init_op = self._tf_iterator.make_initializer(
                self._tf_dataset)
示例#2
0
    def __init__(
        self,
        tfrecord_dir,  # Directory containing a collection of tfrecords files.
        resolution=None,  # Dataset resolution, None = autodetect.
        label_file=None,  # Relative path of the labels file, None = autodetect.
        max_label_size=0,  # 0 = no labels, 'full' = full labels, <int> = N first label components.
        repeat=True,  # Repeat dataset indefinitely.
        shuffle_mb=4096,  # Shuffle data within specified window (megabytes), 0 = disable shuffling.
        prefetch_mb=2048,  # Amount of data to prefetch (megabytes), 0 = disable prefetching.
        buffer_mb=256,  # Read buffer size (megabytes).
        num_threads=2,  # Number of concurrent threads.
    ):
        """Multi-resolution tfrecords dataset for progressive-growing training.

        Discovers one .tfrecords file per level of detail (lod) in
        ``tfrecord_dir``, optionally loads a ``.labels`` file, and builds one
        tf.data pipeline per lod plus a single reinitializable iterator that
        can be switched between them.
        """
        self.tfrecord_dir = tfrecord_dir
        self.resolution = None
        self.resolution_log2 = None
        self.shape = []  # [channel, height, width]
        self.dtype = 'uint8'
        self.dynamic_range = [0, 255]
        self.label_file = label_file
        self.label_size = None  # [component]
        self.label_dtype = None
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        self._tf_datasets = dict()
        self._tf_iterator = None
        self._tf_init_ops = dict()
        self._tf_minibatch_np = None
        self._cur_minibatch = -1
        self._cur_lod = -1

        # Discover the tfrecords files and read one record from each to learn
        # its image shape.
        print("self.tfrecord_dir:", self.tfrecord_dir)
        assert os.path.isdir(self.tfrecord_dir)
        record_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.tfrecords')))
        assert len(record_files) >= 1
        record_shapes = []
        for record_file in record_files:
            read_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
            for raw_record in tf.python_io.tf_record_iterator(record_file, read_opt):
                record_shapes.append(parse_tfrecord_np(raw_record).shape)
                break  # one record per file is enough

        # Fall back to an autodetected labels file when none was given, and
        # resolve a relative labels path against the dataset directory.
        if self.label_file is None:
            candidates = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.labels')))
            if candidates:
                self.label_file = candidates[0]
        elif not os.path.isfile(self.label_file):
            candidate = os.path.join(self.tfrecord_dir, self.label_file)
            if os.path.isfile(candidate):
                self.label_file = candidate

        # Derive the full resolution from the largest record and map every
        # file to its level of detail (lod 0 == full resolution).
        largest = max(record_shapes, key=np.prod)
        self.resolution = resolution if resolution is not None else largest[1]
        self.resolution_log2 = int(np.log2(self.resolution))
        self.shape = [largest[0], self.resolution, self.resolution]
        lods = [self.resolution_log2 - int(np.log2(s[1])) for s in record_shapes]
        assert all(s[0] == largest[0] for s in record_shapes)
        assert all(s[1] == s[2] for s in record_shapes)
        assert all(s[1] == self.resolution // (2 ** lod) for s, lod in zip(record_shapes, lods))
        assert all(lod in lods for lod in range(self.resolution_log2 - 1))

        # Load labels (zero-width placeholder when there are none), truncating
        # to the requested number of components.
        assert max_label_size == 'full' or max_label_size >= 0
        self._np_labels = np.zeros([1 << 20, 0], dtype=np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size:
            self._np_labels = self._np_labels[:, :max_label_size]
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # Build one tf.data pipeline per lod plus the shared iterator.
        with tf.name_scope('Dataset'), tf.device('/cpu:0'):
            self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[])
            self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name='labels_var')
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var)
            for record_file, record_shape, file_lod in zip(record_files, record_shapes, lods):
                if file_lod < 0:
                    continue
                pipeline = tf.data.TFRecordDataset(record_file, compression_type='', buffer_size=buffer_mb << 20)
                pipeline = pipeline.map(parse_tfrecord_tf, num_parallel_calls=num_threads)
                pipeline = tf.data.Dataset.zip((pipeline, self._tf_labels_dataset))
                # Turn the megabyte budgets into item counts (at least 1).
                item_bytes = np.prod(record_shape) * np.dtype(self.dtype).itemsize
                if shuffle_mb > 0:
                    shuffle_items = ((shuffle_mb << 20) - 1) // item_bytes + 1
                    pipeline = pipeline.shuffle(shuffle_items)
                if repeat:
                    pipeline = pipeline.repeat()
                if prefetch_mb > 0:
                    prefetch_items = ((prefetch_mb << 20) - 1) // item_bytes + 1
                    pipeline = pipeline.prefetch(prefetch_items)
                pipeline = pipeline.batch(self._tf_minibatch_in)
                self._tf_datasets[file_lod] = pipeline
            self._tf_iterator = tf.data.Iterator.from_structure(
                self._tf_datasets[0].output_types, self._tf_datasets[0].output_shapes)
            self._tf_init_ops = {
                file_lod: self._tf_iterator.make_initializer(dataset)
                for file_lod, dataset in self._tf_datasets.items()
            }
示例#3
0
    def __init__(
        self,
        tfrecord_dir,  # Directory containing a collection of tfrecords files
        resolution=None,  # Dataset resolution, None = autodetect
        label_file=None,  # Relative path of the labels file, None = autodetect
        max_label_size=0,  # 0 = no labels, "full" = full labels, <int> = N first label components
        max_imgs=None,  # Maximum number of images to use, None = use all images
        ratio=1.0,  # Accepted for interface compatibility; not used in this constructor
        repeat=True,  # Repeat dataset indefinitely?
        shuffle_mb=2048,  # Shuffle data within specified window (megabytes), 0 = disable shuffling
        prefetch_mb=512,  # Amount of data to prefetch (megabytes), 0 = disable prefetching
        buffer_mb=256,  # Read buffer size (megabytes)
        num_threads=4):  # Number of concurrent threads for input processing
        """Multi-resolution, sharded tfrecords dataset.

        Each resolution (level of detail, lod) is stored as one or more shard
        files named "...1of...", "...2of...", etc.  This constructor discovers
        the shards, inspects one record per resolution to learn the image
        shape, optionally loads a .labels file, and builds one tf.data
        pipeline per lod plus a shared reinitializable iterator.

        When `max_imgs` is set, only the first shard of each resolution is
        used and at most `max_imgs` records are taken from it.
        """
        self.tfrecord_dir = tfrecord_dir
        self.resolution = None
        self.resolution_log2 = None
        self.shape = []  # [channels, height, width]
        self.dtype = "uint8"
        self.dynamic_range = [0, 255]
        self.label_file = label_file
        self.label_size = None
        self.label_dtype = None
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        self._tf_datasets = dict()
        self._tf_iterator = None
        self._tf_init_ops = dict()
        self._tf_minibatch_np = None
        self._cur_minibatch = -1
        self._cur_lod = -1

        # List tfrecords files and inspect their shapes.  `tfr_file_lists`
        # always ends up as a list of per-resolution file lists.
        assert os.path.isdir(self.tfrecord_dir)
        tfr_file_lists = sorted(
            glob.glob(os.path.join(self.tfrecord_dir, "*.tfrecords1of*")))
        if max_imgs is None:
            # Use every shard of each resolution.
            tfr_file_lists = [
                sorted(glob.glob(re.sub("1of.*", "*", f)))
                for f in tfr_file_lists
            ]
        else:
            # BUGFIX: the flat list of "...1of..." paths used to be kept
            # as-is, so `tfr_files[0]` below indexed the first *character* of
            # a path string and random.shuffle() received an immutable str
            # (TypeError).  Wrap each first shard in a one-element list so
            # both branches yield a list of file lists.
            tfr_file_lists = [[f] for f in tfr_file_lists]

        assert len(tfr_file_lists) >= 1
        tfr_shapes = []
        for tfr_files in tfr_file_lists:
            tfr_opt = tf.python_io.TFRecordOptions(
                tf.python_io.TFRecordCompressionType.NONE)
            # Only the first record of the first shard is inspected.
            for record in tf.python_io.tf_record_iterator(
                    tfr_files[0], tfr_opt):
                tfr_shapes.append(self.parse_tfrecord_np(record).shape)
                break
            random.shuffle(tfr_files)  # randomize shard order within this lod

        # Autodetect label filename
        if self.label_file is None:
            guess = sorted(
                glob.glob(os.path.join(self.tfrecord_dir, "*.labels")))
            if len(guess):
                self.label_file = guess[0]
        elif not os.path.isfile(self.label_file):
            # Resolve a relative labels path against the dataset directory.
            guess = os.path.join(self.tfrecord_dir, self.label_file)
            if os.path.isfile(guess):
                self.label_file = guess

        # Determine shape and resolution; lod 0 is the full resolution.
        max_shape = max(tfr_shapes, key=np.prod)
        self.resolution = resolution if resolution is not None else max_shape[1]
        self.resolution_log2 = int(np.log2(self.resolution))
        self.shape = [max_shape[0], self.resolution, self.resolution]
        tfr_lods = [
            self.resolution_log2 - int(np.log2(shape[1]))
            for shape in tfr_shapes
        ]
        assert all(shape[0] == max_shape[0] for shape in tfr_shapes)
        assert all(shape[1] == shape[2] for shape in tfr_shapes)
        assert all(shape[1] == self.resolution // (2**lod)
                   for shape, lod in zip(tfr_shapes, tfr_lods))
        assert all(lod in range(self.resolution_log2 - 1) for lod in tfr_lods)

        # Load labels, truncating components and rows as requested.
        assert max_label_size == "full" or max_label_size >= 0
        self._np_labels = np.zeros([1 << 20, 0], dtype=np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != "full" and self._np_labels.shape[
                1] > max_label_size:
            self._np_labels = self._np_labels[:, :max_label_size]
        if max_imgs is not None and self._np_labels.shape[0] > max_imgs:
            self._np_labels = self._np_labels[:max_imgs]
        if max_imgs is not None and self._np_labels.shape[0] < max_imgs:
            # BUGFIX: the old message ("Too many images. increase number.")
            # was backwards — this branch fires when the dataset holds FEWER
            # images than requested.
            print("Not enough images in the dataset for max_imgs; decrease max_imgs.")
            exit()
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # Build TF expressions
        with tf.name_scope("Dataset"), tf.device("/cpu:0"):
            self._tf_minibatch_in = tf.placeholder(tf.int64,
                                                   name="minibatch_in",
                                                   shape=[])
            self._tf_labels_var = tflib.create_var_with_large_initial_value(
                self._np_labels, name="labels_var")
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(
                self._tf_labels_var)
            # BUGFIX: iterate over the per-lod file LISTS.  The old code
            # zipped `tfr_files` — the leftover loop variable holding only the
            # last resolution's shards — against every shape, so each lod's
            # dataset read the wrong files.  TFRecordDataset accepts a list
            # of filenames.
            for lod_files, tfr_shape, tfr_lod in zip(tfr_file_lists,
                                                     tfr_shapes, tfr_lods):
                if tfr_lod < 0:
                    continue
                # Load dataset
                dset = tf.data.TFRecordDataset(lod_files,
                                               compression_type="",
                                               buffer_size=buffer_mb << 20,
                                               num_parallel_reads=None)

                # If max_imgs is set, take a subset of the data
                if max_imgs is not None:
                    dset = dset.take(max_imgs)

                # Parse the TF records
                dset = dset.map(self.parse_tfrecord_tf,
                                num_parallel_calls=num_threads)

                # Zip images with their labels (0s if no labels)
                dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))

                # Shuffle and repeat; the megabyte budgets are converted into
                # item counts (at least 1).
                bytes_per_item = np.prod(tfr_shape) * np.dtype(
                    self.dtype).itemsize
                if shuffle_mb > 0:
                    dset = dset.shuffle((
                        (shuffle_mb << 20) - 1) // bytes_per_item + 1)
                if repeat:
                    dset = dset.repeat()

                # Prefetch and batch
                if prefetch_mb > 0:
                    dset = dset.prefetch((
                        (prefetch_mb << 20) - 1) // bytes_per_item + 1)
                dset = dset.batch(self._tf_minibatch_in)
                self._tf_datasets[tfr_lod] = dset

            # Initialize data iterator
            self._tf_iterator = tf.data.Iterator.from_structure(
                self._tf_datasets[0].output_types,
                self._tf_datasets[0].output_shapes)
            self._tf_init_ops = {
                lod: self._tf_iterator.make_initializer(dset)
                for lod, dset in self._tf_datasets.items()
            }
示例#4
0
    def __init__(
        self,
        tfrecord,  # tfrecords file
        resolution=None,  # Dataset resolution, None = autodetect.
        label_file=None,  # Relative path of the labels file, None = autodetect.
        max_label_size=0,  # 0 = no labels, 'full' = full labels, <int> = N first label components.
        repeat=True,  # Repeat dataset indefinitely?
        shuffle_mb=4096,  # Shuffle data within specified window (megabytes), 0 = disable shuffling.
        prefetch_mb=2048,  # Amount of data to prefetch (megabytes), 0 = disable prefetching.
        buffer_mb=256,  # Read buffer size (megabytes).
        num_threads=2):  # Number of concurrent threads.
        """Single-file tfrecords dataset with non-square resolution support.

        Inspects the first record of ``tfrecord`` to learn the image shape,
        picks a JPEG or raw parser based on that shape, locates a matching
        ``.labels`` file, and builds one tf.data pipeline (parse -> zip with
        labels -> shuffle -> repeat -> prefetch -> batch).
        """

        self.tfrecord = tfrecord
        self.resolution = None
        self.res_log2 = None
        self.shape = []  # [channels, height, width]
        self.dtype = 'uint8'
        self.dynamic_range = [0, 255]
        self.label_file = label_file
        self.label_size = None  # components
        self.label_dtype = None
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        self._tf_dataset = None
        self._tf_iterator = None
        self._tf_init_op = None
        self._tf_minibatch_np = None
        self._cur_minibatch = -1

        # List tfrecords files and inspect their shapes.
        assert os.path.isfile(self.tfrecord)

        tfr_file = self.tfrecord
        data_dir = os.path.dirname(tfr_file)

        tfr_opt = tf.python_io.TFRecordOptions(
            tf.python_io.TFRecordCompressionType.NONE)
        for record in tf.python_io.tf_record_iterator(
                tfr_file, tfr_opt):  # in fact only one
            tfr_shape = self.parse_tfrecord_shape(record)  # [c,h,w]
            # NOTE(review): fewer than 4 "channels" is taken to mean the
            # records hold JPEG-encoded data — confirm the writer's convention.
            jpg_data = tfr_shape[0] < 4
            break

        # Autodetect label filename.
        if self.label_file is None:
            # guess = sorted(glob.glob(os.path.join(data_dir, '*.labels')))
            # `file_list` / `basename` are project helpers; this matches a
            # labels file whose name prefix (before the first '-') equals the
            # tfrecord's prefix.
            guess = [
                ff for ff in file_list(data_dir, 'labels') if basename(
                    ff).split('-')[0] == basename(tfrecord).split('-')[0]
            ]
            if len(guess):
                self.label_file = guess[0]
        elif not os.path.isfile(self.label_file):
            # Resolve a relative labels path against the tfrecord's directory.
            guess = os.path.join(data_dir, self.label_file)
            if os.path.isfile(guess):
                self.label_file = guess

        # Determine shape and resolution; `calc_res` is a project helper,
        # presumably deriving a nominal resolution from [h, w] — confirm.
        self.shape = list(tfr_shape)
        max_res = calc_res(tfr_shape[1:])
        self.resolution = resolution if resolution is not None else max_res
        self.res_log2 = int(np.ceil(np.log2(self.resolution)))
        # Presumably the [h, w] of the coarsest (4px-base) level of the
        # growth schedule — TODO confirm against the network code.
        self.init_res = [
            int(s * 2**(2 - self.res_log2)) for s in self.shape[1:]
        ]

        # Load labels.  The placeholder is 2**30 rows of zero-width labels,
        # used when no labels file is loaded.
        assert max_label_size == 'full' or max_label_size >= 0
        self._np_labels = np.zeros([1 << 30, 0], dtype=np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != 'full' and self._np_labels.shape[
                1] > max_label_size:
            # Keep only the first `max_label_size` label components.
            self._np_labels = self._np_labels[:, :max_label_size]
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # Build TF expressions.
        with tf.name_scope('Dataset'), tf.device('/cpu:0'):
            self._tf_minibatch_in = tf.placeholder(tf.int64,
                                                   name='minibatch_in',
                                                   shape=[])
            self._tf_labels_var = tflib.create_var_with_large_initial_value(
                self._np_labels, name='labels_var')
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(
                self._tf_labels_var)

            dset = tf.data.TFRecordDataset(tfr_file,
                                           compression_type='',
                                           buffer_size=buffer_mb << 20)
            # Pick the parser matching the detected record encoding.
            if jpg_data is True:
                dset = dset.map(self.parse_tfrecord_tf_jpg,
                                num_parallel_calls=num_threads)
            else:
                dset = dset.map(self.parse_tfrecord_tf,
                                num_parallel_calls=num_threads)
            dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))

            # Convert the megabyte budgets into item counts (at least 1).
            bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize
            if shuffle_mb > 0:
                dset = dset.shuffle((
                    (shuffle_mb << 20) - 1) // bytes_per_item + 1)
            if repeat:
                dset = dset.repeat()
            if prefetch_mb > 0:
                dset = dset.prefetch((
                    (prefetch_mb << 20) - 1) // bytes_per_item + 1)
            dset = dset.batch(self._tf_minibatch_in)
            self._tf_dataset = dset

            # self._tf_iterator = tf.data.Iterator.from_structure(tf.data.get_output_types(self._tf_dataset), tf.data.get_output_shapes(self._tf_dataset),)
            self._tf_iterator = tf.data.Iterator.from_structure(
                self._tf_dataset.output_types, self._tf_dataset.output_shapes)
            self._tf_init_op = self._tf_iterator.make_initializer(
                self._tf_dataset)
示例#5
0
    def __init__(
        self,
        tfrecord_dir,  # Directory containing a collection of tfrecords files.
        resolution=None,  # Dataset resolution, None = autodetect.
        label_file=None,  # Relative path of the labels file, None = autodetect.
        max_label_size=0,  # 0 = no labels, 'full' = full labels, <int> = N first label components.
        max_images=None,  # Maximum number of images to use, None = use all images.
        max_validation=10000,  # Maximum size of the validation set, None = use all available images.
        mirror_augment=False,  # Apply mirror augment?
        mirror_augment_v=False,  # Apply mirror augment vertically?
        repeat=True,  # Repeat dataset indefinitely?
        shuffle=True,  # Shuffle images?
        shuffle_mb=4096,  # Shuffle data within specified window (megabytes), 0 = disable shuffling.
        prefetch_mb=2048,  # Amount of data to prefetch (megabytes), 0 = disable prefetching.
        buffer_mb=256,  # Read buffer size (megabytes).
        num_threads=2,  # Number of concurrent threads.
        _is_validation=False,  # Select the 'validation-' files instead of the training files.
        use_raw=False,  # Use the raw-byte record parsers instead of the default ones.
    ):
        """Validation-split-aware multi-lod tfrecords dataset.

        Files whose basename starts with 'validation-' form the validation
        split; `_is_validation` selects either that split or the training
        split.  The constructor inspects one record per .tfrecords file,
        optionally loads labels, and builds one tf.data pipeline per level of
        detail (lod) with a shared reinitializable iterator.
        """
        self.tfrecord_dir = tfrecord_dir
        self.resolution = None
        self.resolution_log2 = None
        self.shape = []  # [channels, height, width]
        self.dtype = 'uint8'
        self.label_file = label_file
        self.label_size = None  # components
        self.label_dtype = None
        self.has_validation_set = None
        # NOTE(review): the mirror/repeat/shuffle flags are only stored here;
        # the augmentation itself is presumably applied elsewhere — confirm.
        self.mirror_augment = mirror_augment
        self.mirror_augment_v = mirror_augment_v
        self.repeat = repeat
        self.shuffle = shuffle
        self._max_validation = max_validation
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        self._tf_datasets = dict()
        self._tf_iterator = None
        self._tf_init_ops = dict()
        self._tf_minibatch_np = None
        self._cur_minibatch = -1
        self._cur_lod = -1

        # List files in the dataset directory.
        assert os.path.isdir(self.tfrecord_dir)
        all_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*')))
        # A validation set exists only if validation is enabled and at least
        # one 'validation-' file is present.
        self.has_validation_set = (self._max_validation > 0) and any(
            os.path.basename(f).startswith('validation-') for f in all_files)
        # Keep only the files belonging to the requested split.
        all_files = [
            f for f in all_files
            if os.path.basename(f).startswith('validation-') == _is_validation
        ]

        # Inspect tfrecords files.
        tfr_files = [f for f in all_files if f.endswith('.tfrecords')]
        assert len(tfr_files) >= 1
        tfr_shapes = []
        for tfr_file in tfr_files:
            tfr_opt = tf.python_io.TFRecordOptions(
                tf.python_io.TFRecordCompressionType.NONE)
            # Only the first record of each file is inspected.
            for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt):
                if use_raw:
                    # parse_tfrecord_np_raw appears to return the shape
                    # directly (no `.shape` access, unlike the branch below) —
                    # TODO confirm against its definition.
                    tfr_shapes.append(self.parse_tfrecord_np_raw(record))
                else:
                    tfr_shapes.append(self.parse_tfrecord_np(record).shape)
                break

        # Autodetect label filename.
        if self.label_file is None:
            guess = [f for f in all_files if f.endswith('.labels')]
            if len(guess):
                self.label_file = guess[0]
        elif not os.path.isfile(self.label_file):
            # Resolve a relative labels path against the dataset directory.
            guess = os.path.join(self.tfrecord_dir, self.label_file)
            if os.path.isfile(guess):
                self.label_file = guess

        # Determine shape and resolution; lod 0 is the full resolution.
        max_shape = max(tfr_shapes, key=np.prod)
        self.resolution = resolution if resolution is not None else max_shape[1]
        self.resolution_log2 = int(np.log2(self.resolution))
        self.shape = [max_shape[0], self.resolution, self.resolution]
        tfr_lods = [
            self.resolution_log2 - int(np.log2(shape[1]))
            for shape in tfr_shapes
        ]
        assert all(shape[0] == max_shape[0] for shape in tfr_shapes)
        assert all(shape[1] == shape[2] for shape in tfr_shapes)
        assert all(shape[1] == self.resolution // (2**lod)
                   for shape, lod in zip(tfr_shapes, tfr_lods))
        #assert all(lod in tfr_lods for lod in range(self.resolution_log2 - 1))

        # Load labels.  The placeholder is 2**30 rows of zero-width labels,
        # used when no labels file is loaded.
        assert max_label_size == 'full' or max_label_size >= 0
        self._np_labels = np.zeros([1 << 30, 0], dtype=np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != 'full' and self._np_labels.shape[
                1] > max_label_size:
            # Keep only the first `max_label_size` label components.
            self._np_labels = self._np_labels[:, :max_label_size]
        if max_images is not None and self._np_labels.shape[0] > max_images:
            self._np_labels = self._np_labels[:max_images]
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # Build TF expressions.  control_dependencies(None) clears any control
        # dependencies inherited from an outer context.
        with tf.name_scope('Dataset'), tf.device(
                '/cpu:0'), tf.control_dependencies(None):
            self._tf_minibatch_in = tf.placeholder(tf.int64,
                                                   name='minibatch_in',
                                                   shape=[])
            self._tf_labels_var = tflib.create_var_with_large_initial_value(
                self._np_labels, name='labels_var')
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(
                self._tf_labels_var)
            for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes,
                                                    tfr_lods):
                if tfr_lod < 0:
                    continue
                dset = tf.data.TFRecordDataset(tfr_file,
                                               compression_type='',
                                               buffer_size=buffer_mb << 20)
                if max_images is not None:
                    dset = dset.take(max_images)
                # Pick the parser matching the record encoding.
                if use_raw:
                    dset = dset.map(self.parse_tfrecord_tf_raw,
                                    num_parallel_calls=num_threads)
                else:
                    dset = dset.map(self.parse_tfrecord_tf,
                                    num_parallel_calls=num_threads)
                dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))

                # Convert the megabyte budgets into item counts (at least 1).
                bytes_per_item = np.prod(tfr_shape) * np.dtype(
                    self.dtype).itemsize
                if self.shuffle and shuffle_mb > 0:
                    dset = dset.shuffle((
                        (shuffle_mb << 20) - 1) // bytes_per_item + 1)
                if self.repeat:
                    dset = dset.repeat()
                if prefetch_mb > 0:
                    dset = dset.prefetch((
                        (prefetch_mb << 20) - 1) // bytes_per_item + 1)
                dset = dset.batch(self._tf_minibatch_in)
                self._tf_datasets[tfr_lod] = dset
            self._tf_iterator = tf.data.Iterator.from_structure(
                self._tf_datasets[0].output_types,
                self._tf_datasets[0].output_shapes)
            self._tf_init_ops = {
                lod: self._tf_iterator.make_initializer(dset)
                for lod, dset in self._tf_datasets.items()
            }
示例#6
0
    def __init__(self,
        tfrecord_dir,               # Directory containing a collection of tfrecords files
        resolution      = None,     # Dataset resolution, None = autodetect
        label_file      = None,     # Relative path of the labels file, None = autodetect
        max_label_size  = 0,        # 0 = no labels, "full" = full labels, <int> = N first label components
        max_imgs        = None,     # Maximum number of images to use, None = use all images
        repeat          = True,     # Repeat dataset indefinitely?
        shuffle_mb      = 2048,     # Shuffle data within specified window (megabytes), 0 = disable shuffling
        prefetch_mb     = 512,      # Amount of data to prefetch (megabytes), 0 = disable prefetching
        buffer_mb       = 256,      # Read buffer size (megabytes)
        num_threads     = 4,        # Number of concurrent threads for input processing
        ratio           = 1.0,      # Image height/width ratio in the dataset
        crop_ratio      = None,     # Crop the data according to the ratio
        **kwargs):
        """Scan a directory of ``*.tfrecords`` files and build the TF input pipeline.

        Peeks at the first record of every tfrecords file to learn the image
        shape, optionally loads a companion ``*.labels`` file, then builds a
        ``tf.data`` pipeline (parse -> zip with labels -> shuffle -> repeat ->
        prefetch -> batch) with a reinitializable iterator. Raises
        AssertionError if the directory or tfrecords files are missing, or if
        the record shapes are inconsistent.
        """
        self.tfrecord_dir       = tfrecord_dir
        self.resolution         = None          # Set below from records (or the `resolution` argument)
        self.resolution_log2    = None
        self.shape              = []            # [channels, height, width]
        self.dtype              = "uint8"
        self.dynamic_range      = [0, 255]
        self.label_file         = label_file
        self.label_size         = None          # Number of label components per image
        self.label_dtype        = None
        self._np_labels         = None
        self._tf_batch_in       = None          # Scalar int64 placeholder holding the batch size
        self._tf_labels_var     = None
        self._tf_labels_dataset = None
        self._tf_dataset        = dict()        # Replaced by the final tf.data.Dataset below
        self._tf_iterator       = None
        self._tf_init_op        = dict()
        self._tf_batch_np       = None
        self._cur_batch         = -1
        self.ratio              = ratio
        self.crop_ratio         = crop_ratio

        # List tfrecords files and inspect their shapes
        assert os.path.isdir(self.tfrecord_dir)
        tfr_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, "*.tfrecords*"))) # 1of
        assert len(tfr_files) >= 1
        # When using all images, shuffle file order so shards are not always
        # read in lexicographic order; with max_imgs keep order deterministic
        # so the taken subset is reproducible.
        if max_imgs is None:
            random.shuffle(tfr_files)

        # Read only the first record of each file -- enough to learn its shape.
        tfr_opt = tf.io.TFRecordOptions("")  # Hoisted: identical for every file
        tfr_shapes = []
        for tfr_file in tfr_files:
            for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt):
                tfr_shapes.append(self.parse_tfrecord_np(record).shape)
                break

        # Autodetect label filename
        if self.label_file is None:
            guess = sorted(glob.glob(os.path.join(self.tfrecord_dir, "*.labels")))
            if len(guess):
                self.label_file = guess[0]
        elif not os.path.isfile(self.label_file):
            # Treat a non-existent label_file as relative to the tfrecord dir
            guess = os.path.join(self.tfrecord_dir, self.label_file)
            if os.path.isfile(guess):
                self.label_file = guess

        # Determine shape and resolution; keep only the files that match the
        # chosen resolution (records are assumed square, [C, H, W] with H == W).
        max_resolution = max([tfr_shape[1] for tfr_shape in tfr_shapes])
        self.resolution = resolution or max_resolution
        self.resolution_log2 = int(np.log2(self.resolution))
        file_indexes = [i for i, tfr_shape in enumerate(tfr_shapes) if tfr_shape[1] == self.resolution]
        tfr_files = [tfr_files[i] for i in file_indexes]
        tfr_shapes = [tfr_shapes[i] for i in file_indexes]
        self.shape = [tfr_shapes[0][0], self.resolution, self.resolution]
        self.data_shape = self.shape.copy()  # Shape of records on disk, before cropping
        assert all(shape[0] == self.shape[0] for shape in tfr_shapes)
        assert all(((shape[1] == shape[2]) and (shape[1] == self.resolution)) for shape in tfr_shapes)
        if self.crop_ratio is not None:
            # BUGFIX: the original in-place float multiply left a non-integer
            # height in self.shape; tensor dimensions must stay integral.
            self.shape[1] = int(self.shape[1] * self.crop_ratio)

        # Load labels. Default: a zero-width float32 array (no labels) sized
        # generously so zipping with images never runs out of label rows.
        assert max_label_size == "full" or max_label_size >= 0
        self._np_labels = np.zeros([1<<20, 0], dtype = np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != "full" and self._np_labels.shape[1] > max_label_size:
            self._np_labels = self._np_labels[:, :max_label_size]
        if max_imgs is not None and self._np_labels.shape[0] > max_imgs:
            self._np_labels = self._np_labels[:max_imgs]
        if max_imgs is not None and self._np_labels.shape[0] < max_imgs:
            # BUGFIX: the original message ("Too many images. increase number.")
            # was inverted -- this branch fires when the dataset has FEWER
            # images than requested.
            print(misc.bcolored("Not enough images in the dataset: decrease max_imgs.", "red"))
            exit()
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # Build TF expressions
        with tf.name_scope("Dataset"), tf.device("/cpu:0"):
            self._tf_batch_in = tf.placeholder(tf.int64, name = "batch_in", shape = [])
            self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name = "labels_var")
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var)

            # Load dataset. With max_imgs, read sequentially (one thread) so
            # the subset taken below is deterministic.
            dset = tf.data.TFRecordDataset(tfr_files, compression_type = "", buffer_size = buffer_mb << 20,
                num_parallel_reads = num_threads if max_imgs is None else 1)

            # If max_imgs is set, take a subset of the data
            if max_imgs is not None:
                dset = dset.take(max_imgs)

            # Parse the TF records
            dset = dset.map(self.parse_tfrecord_tf_builder(), num_parallel_calls = num_threads)

            # Zip images with their labels (0s if no labels)
            dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))

            # Shuffle and repeat. Window sizes are expressed in items, derived
            # from the on-disk (uncropped) record size.
            bytes_per_item = np.prod(self.data_shape) * np.dtype(self.dtype).itemsize
            if shuffle_mb > 0:
                dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1)
            if repeat:
                dset = dset.repeat()

            # Prefetch and batch
            if prefetch_mb > 0:
                dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1)
            dset = dset.batch(self._tf_batch_in)
            self._tf_dataset = dset

            # Initialize data iterator
            self._tf_iterator = tf.data.Iterator.from_structure(self._tf_dataset.output_types,
                self._tf_dataset.output_shapes)
            self._tf_init_op = self._tf_iterator.make_initializer(self._tf_dataset)
示例#7
0
    def __init__(
        self,
        tfrecord_dir,  # Directory containing a collection of tfrecords files.
        split='train',  # Dataset split, 'train' or 'test'
        from_tfrecords=False,  # Load from tfrecords or from tensorflow datasets
        tfds_data_dir=None,  # Directory from which tensorflow datasets load
        resolution=None,  # Dataset resolution, None = autodetect.
        label_file=None,  # Relative path of the labels file, None = autodetect.
        max_label_size=0,  # 0 = no labels, 'full' = full labels, <int> = N first label components.
        num_samples=None,  # Maximum number of images to use, None = use all images.
        num_val_images=None,  # Number of validation images split from the training set, None = use separate validation set.
        repeat=True,  # Repeat dataset indefinitely?
        shuffle_mb=4096,  # Shuffle data within specified window (megabytes), 0 = disable shuffling.
        prefetch_mb=2048,  # Amount of data to prefetch (megabytes), 0 = disable prefetching.
        buffer_mb=256,  # Read buffer size (megabytes).
        num_threads=2):  # Number of concurrent threads.
        """Build an input pipeline from local tfrecords or tensorflow-datasets.

        With from_tfrecords=True, scans `tfrecord_dir` (or its '-val' sibling
        for the test split when num_val_images is None), picks the file with
        the largest record shape, and loads integer class labels from a
        companion '*.labels' file. Otherwise loads the named dataset via
        tfds.load(). Either way the result is a batched tf.data pipeline
        (parse -> shuffle -> repeat -> prefetch -> batch) behind a
        reinitializable iterator.
        """

        self.tfrecord_dir = tfrecord_dir
        self.resolution = None
        self.shape = []  # [channels, height, width]
        self.dtype = 'uint8'
        self.dynamic_range = [0, 255]
        self.from_tfrecords = from_tfrecords
        self.label_file = label_file
        self.label_size = None  # components
        self.label_dtype = None
        # num_samples only caps the training split; for other splits the count
        # is derived below (from labels or tfds info).
        self.num_samples = num_samples if split == 'train' else None
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        self._tf_dataset = None
        self._tf_iterator = None
        self._tf_init_op = None
        self._tf_minibatch_np = None
        self._cur_minibatch = -1

        # List tfrecords files and inspect their shapes.
        if self.from_tfrecords:
            self.name = os.path.basename(self.tfrecord_dir)
            if resolution is not None:
                self.name += '-{}'.format(resolution)
            # Test split reads a separate '<dir>-val' directory unless the
            # validation set is carved out of the training records below.
            data_dir = self.tfrecord_dir + '-val' if num_val_images is None and split == 'test' else self.tfrecord_dir
            tfr_files = sorted(glob.glob(os.path.join(data_dir,
                                                      '*.tfrecords')))
            assert len(tfr_files) >= 1
            tfr_shapes = []
            # Read only the first record of each file -- enough to learn its shape.
            for tfr_file in tfr_files:
                tfr_opt = tf.python_io.TFRecordOptions(
                    tf.python_io.TFRecordCompressionType.NONE)
                for record in tf.python_io.tf_record_iterator(
                        tfr_file, tfr_opt):
                    tfr_shapes.append(self.parse_tfrecord_np(record).shape)
                    break

            # Autodetect label filename.
            if self.label_file is None:
                guess = sorted(glob.glob(os.path.join(data_dir, '*.labels')))
                if len(guess):
                    self.label_file = guess[0]
            elif not os.path.isfile(self.label_file):
                # Treat a non-existent label_file as relative to data_dir.
                guess = os.path.join(data_dir, self.label_file)
                if os.path.isfile(guess):
                    self.label_file = guess

            # Determine shape and resolution.
            # Only the single file holding the largest-shaped (by element
            # count) records is used; records must be square.
            max_shape = max(tfr_shapes, key=np.prod)
            tfr_file = [
                tfr_file for tfr_shape, tfr_file in zip(tfr_shapes, tfr_files)
                if tfr_shape == max_shape
            ][0]
            tfr_shape = max_shape
            assert tfr_shape[1] == tfr_shape[2]

            dset = tf.data.TFRecordDataset(tfr_file,
                                           compression_type='',
                                           buffer_size=buffer_mb << 20)

            # Placeholder labels (all zeros, oversized) used when no label
            # file is loaded; sliced down to the split size below.
            self._np_labels = np.zeros([1 << 30], dtype=np.int32)
            if self.label_file is not None and max_label_size != 0:
                # Integer class labels; label_size = number of classes.
                # Assumes labels are 0..max with every class present.
                self._np_labels = np.load(self.label_file).astype(np.int32)
                self.label_size = self._np_labels.max() + 1
                assert self._np_labels.ndim == 1
                assert np.unique(self._np_labels).shape[0] == self.label_size
            else:
                self.label_size = 0

            # Carve the validation split off the front of the training records.
            if num_val_images is not None:
                if split == 'test':
                    dset = dset.take(num_val_images)
                    self._np_labels = self._np_labels[:num_val_images]
                else:
                    dset = dset.skip(num_val_images)
                    self._np_labels = self._np_labels[num_val_images:]
            if self.num_samples is not None and self._np_labels.shape[
                    0] > self.num_samples:
                self._np_labels = self._np_labels[:self.num_samples]
            self.num_samples = self._np_labels.shape[0]
        else:
            # tensorflow-datasets path: tfrecord_dir is interpreted as a tfds
            # dataset name.
            self.name = self.tfrecord_dir
            dset, info = tfds.load(name=self.name,
                                   data_dir=tfds_data_dir,
                                   split=split,
                                   with_info=True)
            if max_label_size != 0:
                self.label_size = info.features['label'].num_classes
            else:
                self.label_size = 0
            if self.num_samples is None:
                self.num_samples = info.splits[split].num_examples
            # Reorder tfds HWC image shape to CHW to match the tfrecords path.
            tfr_shape = [
                int(dset.output_shapes['image'][d]) for d in [2, 0, 1]
            ]

        self.resolution = max(tfr_shape[1], tfr_shape[2])
        if resolution is not None and resolution != self.resolution:
            # Records will be resized on the fly in the parse functions.
            self.resolution = resolution
            resize = True
        else:
            resize = False
        self.resolution_log2 = int(np.ceil(np.log2(self.resolution)))
        self.shape = [tfr_shape[0], self.resolution, self.resolution]

        # Build TF expressions.
        with tf.name_scope('Dataset'), tf.device('/cpu:0'):
            self._tf_minibatch_in = tf.placeholder(tf.int64,
                                                   name='minibatch_in',
                                                   shape=[])
            # NOTE(review): guards on the `num_samples` argument but takes
            # `self.num_samples`, which both branches above have rewritten by
            # this point -- confirm the two are meant to coincide.
            if num_samples is not None:
                dset = dset.take(self.num_samples)
            if self.from_tfrecords:
                dset = dset.map(functools.partial(self.parse_tfrecord_tf,
                                                  resize=resize),
                                num_parallel_calls=num_threads)
                self._tf_labels_var = tflib.create_var_with_large_initial_value(
                    self._np_labels, name='labels_var')
                self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(
                    self._tf_labels_var)
                dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))
            else:
                # tfds examples carry their own labels; no zip needed.
                dset = dset.map(functools.partial(self.parse_tfdataset_tf,
                                                  resize=resize),
                                num_parallel_calls=num_threads)
            # Shuffle/prefetch windows are expressed in items, derived from
            # the on-disk (pre-resize) record size.
            bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize
            if shuffle_mb > 0:
                dset = dset.shuffle((
                    (shuffle_mb << 20) - 1) // bytes_per_item + 1)
            if repeat:
                dset = dset.repeat()
            if prefetch_mb > 0:
                dset = dset.prefetch((
                    (prefetch_mb << 20) - 1) // bytes_per_item + 1)
            dset = dset.batch(self._tf_minibatch_in)
            self._tf_dataset = dset

            self._tf_iterator = tf.data.Iterator.from_structure(
                self._tf_dataset.output_types, self._tf_dataset.output_shapes)
            self._tf_init_op = self._tf_iterator.make_initializer(
                self._tf_dataset)